amica-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
amica/linalg.py ADDED
@@ -0,0 +1,349 @@
1
+ """Whitening, unmixing helpers, determinant computation, etc."""
2
+
3
+ from typing import Literal
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from amica._types import ComponentsVector, DataArray2D, WeightsArray
9
+
10
+ from .utils._logging import log
11
+
12
+
13
def get_unmixing_matrices(
    *,
    c,
    A,
    W,
):
    """Derive the unmixing matrix ``W = A^{-1}`` and the projected bias ``wc``.

    Mirrors the Fortran AMICA sequence DCOPY -> DGETRF/DGETRI -> DGEMV:
    copy the mixing matrix into W, invert W in place, then apply the
    inverse to the bias vector c.
    """
    #--------------------------FORTRAN CODE-------------------------
    # call DCOPY(nw*nw,A(:,comp_list(:,h)),1,W(:,:,h),1)
    #---------------------------------------------------------------
    W[:, :] = A.clone()

    #--------------------------FORTRAN CODE-------------------------
    # call DGETRF(nw,nw,W(:,:,h),nw,ipivnw,info)
    # call DGETRI(nw,W(:,:,h),nw,ipivnw,work,lwork,info)
    #---------------------------------------------------------------
    try:
        W[:, :] = torch.linalg.inv(W)
    except RuntimeError as err:
        log("Matrix W is singular!")
        raise err

    #--------------------------FORTRAN CODE-------------------------
    # call DGEMV('N',nw,nw,dble(1.0),W(:,:,h),nw,c(:,h),1,dble(0.0),wc(:,h),1)
    #---------------------------------------------------------------
    wc = torch.zeros_like(c)
    wc[:] = torch.matmul(W, c)

    return W, wc
42
+
43
+
44
def compute_sign_log_determinant(
    *,
    unmixing_matrix: WeightsArray,
    minlog: float = -1500,
    mode: Literal["strict", "fallback"] = "strict",
) -> tuple[Literal[-1, 1], float]:
    """Return the sign and natural log|det| of one model's unmixing matrix.

    Parameters
    ----------
    unmixing_matrix : array, shape (n_components, n_features)
        The unmixing matrix W of a single model h (a 2D slice of state.W).
    minlog : float
        Substitute log|det| used when the matrix is singular and ``mode`` is
        ``"fallback"``. Default is -1500.
    mode : {"strict", "fallback"}
        How to treat a singular matrix (zero determinant):
        - "strict" (default): raise ``ValueError``.
        - "fallback": log a warning, then return ``(-1, minlog)`` so that
          training can continue.

    Returns
    -------
    sign : {-1, 1}
        Sign of the determinant; never 0 ("fallback" substitutes -1 to keep
        that invariant).
    logabsdet : float
        Natural log of the absolute determinant, or ``minlog`` in "fallback"
        mode when the matrix is singular.
    """
    #--------------------------------FORTRAN CODE------------------------------
    # call DCOPY(nw*nw,W(:,:,h),1,Wtmp,1)
    # call DGEQRF(nw,nw,Wtmp,nw,wr,work,lwork,info)
    # Dtemp(h) = Dtemp(h) + log(abs(Wtmp(i,i)))
    # ------------------------------------------------------------------------
    # Alias for parity with the Fortran naming.
    W = unmixing_matrix
    sign, logabsdet = torch.linalg.slogdet(W)
    # TODO: slogdet requires a square unmixing matrix. Does AMICA gaurantee this?
    # Non-singular: nothing else to do.
    if sign != 0 and logabsdet != -torch.inf:
        return sign, logabsdet

    # A zero determinant means the model fit has collapsed.
    msg = (
        "Unmixing matrix (W) is singular (determinant = 0)\n\n"
        "Details:\n"
        f"- shape of W: {W.shape}\n"
        f"- sign={sign}, log|det|={logabsdet}\n\n"
        "Things to try:\n"
        "- Check that your input data is rank-sufficient\n"
        "- Reduce the number of components\n"
    )
    if mode == "strict":
        # Raise until the fallback path has proper test coverage.
        raise ValueError(msg)
    log(msg, level="warning")
    log(f"Setting log-determinant to {minlog} and sign to -1", level="warning")
    # Numerical hack to let training continue; sign matches the Fortran
    # convention dsign = 1 if det > 0 else -1.
    return -1, minlog
109
+
110
+
111
def get_initial_model_log_likelihood(
    *,
    unmixing_logdet: float,
    whitening_logdet: float,
    model_weight: float,
) -> float:
    """
    Initialize the per-sample model log-likelihood with baseline terms.

    Parameters
    ----------
    unmixing_logdet : float
        The log-determinant of the unmixing matrix (W) for this model.
    whitening_logdet : float
        The log-determinant of the sphering/whitening transform (S),
        computed from the input data's whitening/sphering matrix. e.g. It's computed
        as -0.5 ∑ log(λ_i) where λ_i are covariance eigenvalues.
    model_weight : float
        The mixture proportion (prior probability) for this model. Must be
        strictly positive.

    Returns
    -------
    initial_modloglik : float
        A scalar baseline log-likelihood value. This should be broadcast across all
        samples of the model log-likelihood array at the call site.

    Raises
    ------
    ValueError
        If ``model_weight`` is not strictly positive (its log would be
        undefined or -inf).

    Notes
    -----
    - The Jacobian from x → u is |det(W S)|, so log|det(W S)| = log|det(W)| +
      log|det(S)| = Dsum[h] + sldet.
    - S is positive-definite with full-rank whitening, so sldet has no sign
      issue.
    - In the Fortran code, the variable Ptmp(bstrt:bstp,h) holds the initial
      model log-likelihood for model h across the data block (bstrt:bstp). This gets
      copied into modloglik.
    """
    # Validate before any tensor conversion so the error message shows the
    # caller's original value.
    if model_weight <= 0:
        raise ValueError(
            f"model_weight must be > 0, got {model_weight}"
        )  # pragma no cover noqa: E501
    whitening_logdet = torch.as_tensor(whitening_logdet, dtype=torch.float64)
    unmixing_logdet = torch.as_tensor(unmixing_logdet, dtype=torch.float64)
    # BUG FIX: torch.log() rejects plain Python floats; convert model_weight to
    # a tensor first so the documented float input works (tensors pass through
    # torch.as_tensor unchanged, so tensor callers are unaffected).
    model_weight = torch.as_tensor(model_weight, dtype=torch.float64)
    #--------------------------FORTRAN CODE-------------------------
    # Ptmp(bstrt:bstp,h) = Dsum(h) + log(gm(h)) + sldet
    #---------------------------------------------------------------
    initial_modloglik = unmixing_logdet + torch.log(model_weight) + whitening_logdet
    return initial_modloglik
158
+
159
+
160
def pre_whiten(
    *,
    X: DataArray2D,
    n_components: int | None = None,
    mineig: float = 1e-6,
    do_mean: bool = True,
    do_sphere: bool = True,
    do_approx_sphere: bool = True,
    inplace: bool = True,
) -> tuple[DataArray2D, WeightsArray, float, WeightsArray, ComponentsVector | None]:
    """
    Pre-whiten the input data matrix X prior to ICA.

    Parameters
    ----------
    X : array, shape (``n_samples``, ``n_features``)
        Input data matrix to be whitened. If ``inplace`` is ``True``, X will be
        mutated and returned as the whitened data. Otherwise a copy will be made and
        returned.
    n_components : int or None
        Number of components to keep. If ``None``, all components are kept.
    mineig : float
        Minimum eigenvalue threshold for keeping components. Eigenvalues below this will
        be discarded.
    do_mean : bool
        If ``True``, mean-center the data before whitening.
    do_sphere : bool
        If ``True``, perform sphering (whitening). If ``False``, only variance
        normalization is performed.
    do_approx_sphere : bool
        If ``True``, data is whitened with the inverse of the symmetric square root of
        the covariance matrix (ZCA whitening). If ``False``, PCA whitening is
        performed. Only used if ``do_sphere`` is ``True``.
    inplace : bool
        If ``True``, modify X in place. If ``False``, make a copy of X and modify that.

    Returns
    -------
    X : array, shape (``n_samples``, ``n_features``)
        The whitened data matrix. This is a copy of the input data if inplace is False,
        otherwise it is the mutated input data itself.
    whitening_matrix : array, shape (``n_features``, ``n_features``)
        The whitening/sphering matrix applied to the data. If do_sphere is False, then
        this is the variance normalization matrix.
    sldet : float
        The log-determinant of the whitening matrix.
    whitening_inverse : array, shape (``n_features``, ``n_features``)
        The inverse of the full-rank symmetric whitening transform, built from
        the covariance eigendecomposition. Always returned (regardless of
        ``do_sphere``).
    mean : array, shape (``n_features``,)
        The mean of each feature that was subtracted if ``do_mean`` is ``True``,
        otherwise ``None``.
    """
    dataseg = X if inplace else X.copy()
    assert dataseg.ndim == 2, f"X must be 2D, got {dataseg.ndim}D"
    # !---------------------------- get the mean --------------------------------
    n_samples, nx = dataseg.shape
    if n_components is None:
        n_components = nx

    # ---- Mean-centering ----
    if do_mean:
        log("getting the mean ...")
        mean = dataseg.mean(axis=0)
        # !--- subtract the mean
        dataseg -= mean[None, :]  # Subtract mean from each channel

    # ---- Covariance ----
    log(" Getting the covariance matrix ...")
    # The Fortran code only computes the lower triangular part (DSYRK);
    # here the full symmetric covariance is formed directly.
    # -------------------- FORTRAN CODE ------------------------------------------------
    # call DSCAL(nx*nx,dble(0.0),Stmp,1)
    # call DSYRK('L','N',nx,blk_size(seg),dble(1.0),dataseg(seg)%data(:,bstrt:bstp)...
    # call DSCAL(nx*nx,dble(1.0)/dble(cnt),S,1)
    #-----------------------------------------------------------------------------------
    Cov = dataseg.T @ dataseg / n_samples

    # ---- Eigen-decomposition ----
    log(f"doing eigenvalue decomposition for {nx} features ...")
    eigvals, eigvecs = np.linalg.eigh(Cov)  # ascending order

    min_eigs = eigvals[:min(nx//2, 3)]
    max_eigs = eigvals[::-1][:3]
    log(f"minimum eigenvalues: {min_eigs}")
    log(f"maximum eigenvalues: {max_eigs}")

    # Keep only well-conditioned directions (Fortran: pcakeep).
    # Do we need to pass numeigs to optimize if sum(eigvals > mineig) < n_components?
    numeigs = min(n_components, sum(eigvals > mineig))  # np.linalg.matrix_rank?
    log(f"num eigvals kept: {numeigs}")

    # Log-determinant of the whitening matrix: -0.5 * sum(log eigenvalues),
    # restricted to the numeigs largest eigenvalues when rank-deficient.
    if numeigs == nx:
        sldet = -0.5 * np.sum(np.log(eigvals))
    else:
        sldet = -0.5 * np.sum(np.log(eigvals[::-1][:numeigs]))

    # 1) Reorder eigenvectors by descending eigenvalue.
    order = np.argsort(eigvals)[::-1]
    eigvals_desc = eigvals[order]
    # UNSCALED eigenvectors (as rows).
    Stmp = eigvecs[:, order].T.copy()
    Stmp2 = Stmp.copy()
    # 2) SCALED eigenvectors: rows divided by sqrt(eigenvalue) (PCA whitener).
    Stmp2[:numeigs, :] /= np.sqrt(eigvals_desc[:numeigs, None])

    # ---- Sphere or variance normalize ----
    if do_sphere:
        log("Sphering the data...", level="info", color="blue", weight="bold")
        if numeigs == nx:
            # call DSCAL(nx*nx,dble(0.0),S,1)
            if do_approx_sphere:
                # ZCA whitening: S = V diag(1/sqrt(lambda)) V^T, the inverse
                # symmetric square root of the covariance.
                S = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T
            else:
                # call DCOPY(nx*nx,Stmp2,1,S,1)
                # PCA whitening: scaled eigenvector rows.
                S = Stmp2.copy()
        else:
            if do_approx_sphere:
                # -------------------- FORTRAN CODE ------------------------------------
                # call DSCAL(nx*blk_size(seg),dble(0.0),xtmp(:,1:blk_size(seg)),1)
                # call DGEMM('N','N',nx,blk_size(seg),nx,dble(1.0),S,nx,dataseg(seg)...
                # call DCOPY(nx*blk_size(seg),xtmp(:,1:blk_size(seg)),1,dataseg(seg)
                # ----------------------------------------------------------------------
                # Direct translation of the (hard-to-follow) Fortran code;
                # presumably an approximate ZCA whitener built on the leading
                # rank-numeigs subspace — confirm against the reference.

                # 3) Unscaled eigenvectors go into the SVD.
                S = Stmp.copy()

                # 4) SVD on the leading block of the unscaled eigenvectors.
                U, _, VT = np.linalg.svd(S[:numeigs, :numeigs], full_matrices=True)
                Stmp[:numeigs, :numeigs] = VT.T
                S[:numeigs, :numeigs] = U.T

                # 5) Stmp3 = V @ U^T (numeigs x numeigs block).
                Stmp3 = Stmp[:numeigs, :numeigs] @ S[:numeigs, :numeigs]

                # 6) Zero S and form the final whitener S = Stmp3 @ Stmp2.
                S.fill(0.0)
                S[:numeigs, :] = Stmp3 @ Stmp2[:numeigs, :]
            else:
                # PCA whitening restricted to the kept components.
                S = Stmp2.copy()
    else:
        # !--- just normalize by the channel variances (don't sphere)
        # -------------------- FORTRAN CODE ---------------------------------------
        # call DCOPY(nx*nx,S,1,Stmp,1)
        # call DSCAL(nx*nx,dble(0.0),S,1)
        #------------------------------------------------------------------------
        S = np.zeros_like(Cov)  # This is S in Fortran code
        sldet = 0.0
        for i in range(nx):
            # PERF: the original read np.triu(Cov)[i, i], rebuilding an n x n
            # matrix per iteration; triu leaves the diagonal unchanged, so
            # reading Cov[i, i] directly is exactly equivalent.
            if Cov[i, i] > 0:
                S[i, i] = 1.0 / np.sqrt(Cov[i, i])
                # NOTE(review): accumulates 0.5*log(S_ii) rather than the
                # usual log-det contribution log(S_ii); presumed to mirror the
                # Fortran reference — confirm before changing.
                sldet += 0.5 * np.log(S[i, i])
        numeigs = nx
    # Apply the whitening matrix to the data.
    # -------------------- FORTRAN CODE ---------------------------------------
    # call DSCAL(nx*blk_size(seg),dble(0.0),xtmp(:,1:blk_size(seg)),1)
    # call DGEMM('N','N',nx,blk_size(seg),nx,dble(1.0),S,nx,dataseg(seg)%data(:,bstrt...
    # call DCOPY(nx*blk_size(seg),xtmp(:,1:blk_size(seg)),1,dataseg(seg)%data(:,bstrt...
    # -------------------------------------------------------------------------
    dataseg = np.matmul(dataseg, S.T, out=dataseg)  # In-place if possible

    nw = numeigs  # Number of weights, as per Fortran code
    log(f"numeigs = {numeigs}, nw = {nw}")
    # ! get the pseudoinverse of the sphering matrix
    # call DGESVD( 'A', 'S', numeigs, nx, Stmp2, nx, eigvals, sUtmp, numeigs, sVtmp...
    Winv = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T  # Inverse of the whitening matrix

    # NOTE(review): n_components cannot be None here (it was normalized to nx
    # above), so only the rank check below is live.
    if n_components is None:
        n_components = nw
    elif n_components > nw:
        raise ValueError(
            f"n_components must be less than or equal to the rank of the data. "
            f"Got a rank of {nw} but {n_components} requested components."
        )

    if not do_mean:
        mean = None
    # BUG FIX: the second fragment was missing its f-prefix, so {n_samples}
    # and {nx} appeared literally in the assertion message.
    assert dataseg.shape == (n_samples, nx), (
        f"dataseg shape {dataseg.shape} "
        f"!= (n_samples, n_features) = ({n_samples}, {nx}) "
    )
    return dataseg, S, sldet, Winv, mean