amica-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amica/__init__.py +5 -0
- amica/_batching.py +194 -0
- amica/_newton.py +77 -0
- amica/_sklearn_interface.py +387 -0
- amica/_types.py +44 -0
- amica/conftest.py +30 -0
- amica/constants.py +47 -0
- amica/core.py +1165 -0
- amica/datasets.py +15 -0
- amica/kernels.py +1308 -0
- amica/linalg.py +349 -0
- amica/state.py +385 -0
- amica/tests/test_amica.py +497 -0
- amica/utils/__init__.py +36 -0
- amica/utils/_logging.py +64 -0
- amica/utils/_progress.py +34 -0
- amica/utils/_verbose.py +14 -0
- amica/utils/fetch.py +274 -0
- amica/utils/fortran.py +387 -0
- amica/utils/imports.py +46 -0
- amica/utils/mne.py +74 -0
- amica/utils/parallel.py +72 -0
- amica/utils/simulation.py +36 -0
- amica/utils/tests/test_fetch.py +9 -0
- amica/utils/tests/test_fortran.py +47 -0
- amica/utils/tests/test_imports.py +0 -0
- amica/utils/tests/test_logger.py +29 -0
- amica/utils/tests/test_mne.py +27 -0
- amica_python-0.1.0.dist-info/METADATA +196 -0
- amica_python-0.1.0.dist-info/RECORD +33 -0
- amica_python-0.1.0.dist-info/WHEEL +5 -0
- amica_python-0.1.0.dist-info/licenses/LICENSE +25 -0
- amica_python-0.1.0.dist-info/top_level.txt +1 -0
amica/core.py
ADDED
|
@@ -0,0 +1,1165 @@
|
|
|
1
|
+
"""Module containing the AMICA function entry point."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from numpy.testing import assert_allclose
|
|
7
|
+
|
|
8
|
+
from amica._types import (
|
|
9
|
+
DataTensor2D,
|
|
10
|
+
)
|
|
11
|
+
from amica.constants import (
|
|
12
|
+
doscaling,
|
|
13
|
+
epsdble,
|
|
14
|
+
invsigmax,
|
|
15
|
+
invsigmin,
|
|
16
|
+
lratefact,
|
|
17
|
+
maxdecs,
|
|
18
|
+
maxincs,
|
|
19
|
+
maxrho,
|
|
20
|
+
mineig,
|
|
21
|
+
minlog,
|
|
22
|
+
minlrate,
|
|
23
|
+
minrho,
|
|
24
|
+
outstep,
|
|
25
|
+
rholratefact,
|
|
26
|
+
share_comps,
|
|
27
|
+
share_iter,
|
|
28
|
+
share_start,
|
|
29
|
+
use_grad_norm,
|
|
30
|
+
use_min_dll,
|
|
31
|
+
)
|
|
32
|
+
from amica.kernels import (
|
|
33
|
+
accumulate_alpha_stats,
|
|
34
|
+
accumulate_beta_stats,
|
|
35
|
+
accumulate_c_stats,
|
|
36
|
+
accumulate_kappa_stats,
|
|
37
|
+
accumulate_lambda_stats,
|
|
38
|
+
accumulate_mu_stats,
|
|
39
|
+
accumulate_rho_stats,
|
|
40
|
+
accumulate_sigma2_stats,
|
|
41
|
+
compute_mixture_responsibilities,
|
|
42
|
+
compute_model_loglikelihood_per_sample,
|
|
43
|
+
compute_model_responsibilities,
|
|
44
|
+
compute_preactivations,
|
|
45
|
+
compute_scaled_scores,
|
|
46
|
+
compute_source_densities,
|
|
47
|
+
compute_source_scores,
|
|
48
|
+
compute_total_loglikelihood_per_sample,
|
|
49
|
+
compute_weighted_responsibilities,
|
|
50
|
+
precompute_weighted_scores,
|
|
51
|
+
)
|
|
52
|
+
from amica.linalg import (
|
|
53
|
+
compute_sign_log_determinant,
|
|
54
|
+
get_initial_model_log_likelihood,
|
|
55
|
+
get_unmixing_matrices,
|
|
56
|
+
pre_whiten,
|
|
57
|
+
)
|
|
58
|
+
from amica.state import (
|
|
59
|
+
AmicaAccumulators,
|
|
60
|
+
AmicaConfig,
|
|
61
|
+
AmicaState,
|
|
62
|
+
IterationMetrics,
|
|
63
|
+
get_initial_state,
|
|
64
|
+
initialize_accumulators,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
from ._batching import BatchLoader, choose_batch_size
|
|
68
|
+
from ._newton import compute_newton_terms
|
|
69
|
+
from .utils._logging import log, set_log_level
|
|
70
|
+
from .utils._progress import make_progress_bar
|
|
71
|
+
from .utils._verbose import _validate_verbose
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def fit_amica(
    X,
    *,
    whiten="zca",
    mean_center=True,
    n_components=None,
    device="cpu",
    n_mixtures=3,
    max_iter=500,
    tol=1e-7,
    lrate=0.05,
    rholrate=0.05,
    pdftype=0,
    do_newton=True,
    newt_start=50,
    newtrate=1.0,
    newt_ramp=10,
    batch_size=None,
    w_init=None,
    sbeta_init=None,
    mu_init=None,
    do_reject=False,
    random_state=None,
    verbose=1,
):
    """Perform Adaptive Mixture Independent Component Analysis (AMICA).

    Implements the AMICA algorithm as described in :footcite:t:`palmer2012` and
    :footcite:t:`palmer2008`, and originally implemented in :footcite:t:`amica`.

    Parameters
    ----------
    X : array-like, shape (``n_samples``, ``n_features``)
        Training data, where ``n_samples`` is the number of samples and
        ``n_features`` is the number of features. The input is converted with
        :func:`numpy.asarray` and copied before whitening, so the caller's
        array is not modified.
    whiten : str {"zca", "pca", "variance"}
        Whitening method to apply to the data before fitting AMICA. Options are:
        - "zca": Zero-phase component analysis (ZCA) whitening.
        - "pca": Principal component analysis (PCA) whitening.
        - "variance": Only variance normalization of the features is done (no
          sphering). Any value other than "zca" or "pca" disables sphering.
    mean_center : bool, optional
        If ``True``, X is mean corrected.
    n_components : int, optional
        Number of components to extract. If ``None`` (default), set to
        ``n_features``. Note that the number of components may be reduced during
        whitening if the data are rank-deficient.
    device : str, optional
        Device to run the computations on. Can be either 'cpu' or 'cuda' for GPU
        acceleration. Note that using 'cuda' requires a compatible NVIDIA GPU and
        the appropriate CUDA drivers installed.
    n_mixtures : int, optional, default=3
        Number of mixture components to use in the Gaussian Mixture Model (GMM)
        for each component's source density.
    max_iter : int, optional
        Maximum number of iterations to perform. Default is ``500``.
    tol : float, default=1e-7
        Convergence tolerance forwarded to the optimizer, which uses it as the
        threshold of its stopping criteria.
    lrate : float, default=0.05
        Initial learning rate for the natural gradient.
    rholrate : float, default=0.05
        Initial learning rate for the shape parameters.
    pdftype : int, default=0
        Type of source density model to use. Currently only ``0`` is supported,
        which corresponds to the Gaussian Mixture Model (GMM) density.
    do_newton : bool, default=True
        If ``True``, the optimization method switches from Stochastic Gradient
        Descent (SGD) to Newton updates after ``newt_start`` iterations. If
        ``False``, only SGD updates are used.
    newt_start : int, default=50
        Number of iterations before switching to Newton updates if ``do_newton``
        is ``True``.
    newtrate : float, default=1.0
        Learning rate for Newton iterations.
    newt_ramp : int, default=10
        Forwarded to ``AmicaConfig``; presumably the number of iterations over
        which the Newton learning rate is ramped up (TODO confirm).
    batch_size : int, optional
        Batch size for processing data in chunks along the samples axis. If
        ``None``, the batch size is chosen automatically to keep peak memory under
        ~1.5 GB, and a warning is emitted if the batch size is below ~8k samples.
        If the input data is small enough to process in one shot, no batching is
        used. To enforce no batching, override this memory cap by setting
        ``batch_size`` explicitly, e.g. to ``X.shape[0]``, but note that this may
        lead to high memory usage for large datasets.
    w_init : array-like, shape (``n_components``, ``n_components``), optional
        Initial weights for the mixture components. If None, weights are
        initialized randomly. This is meant to be used for testing and debugging
        purposes only.
    sbeta_init : array-like, shape (``n_components``, ``n_mixtures``), optional
        Initial scales (sbeta) for the mixture components. If None, scales are
        initialized randomly. This is meant to be used for testing and debugging
        purposes only.
    mu_init : array-like, shape (``n_components``, ``n_mixtures``), optional
        Initial locations (mu) for the mixture components. If None, locations
        are initialized randomly. This is meant to be used for testing and
        debugging purposes only.
    do_reject : bool, default=False
        Reject samples by log-likelihood. Not yet supported: passing ``True``
        raises :class:`NotImplementedError` before any work is done.
    random_state : int or None, optional (default=None)
        Seed for the random number generator used to draw the initial
        unmixing weights, scales and locations that are not supplied through
        ``w_init``, ``sbeta_init`` and ``mu_init``. It is not used by whitening.
    verbose : int, default=1
        Output mode during optimization:

        - ``0``: silent
        - ``1``: progress bar
        - ``2``: per-iteration FORTRAN-style logs

    Returns
    -------
    results : dict
        Dictionary containing the following entries:

        - mean : array, shape (``n_features``,) | ``None``
            The mean over features. If ``mean_center=False``, this is ``None``.
        - S : array, shape (``n_components``, ``n_features``)
            The sphering (whitening) matrix applied to the data.
        - W : array, shape (``n_components``, ``n_components``)
            The unmixing matrix.
        - A : array, shape (``n_components``, ``n_components``)
            The mixing matrix in the space of sphered data. To get the mixing
            matrix in the original data space, use ``np.linalg.pinv(S) @ A``.
        - LL : array, shape (``max(1, max_iter)``,)
            The log-likelihood values at each iteration. If the algorithm
            converged before reaching ``max_iter``, the remaining entries are
            zero.
        - gm : array, shape (1,)
            The model weights. Since only one model is fitted, this has
            shape (1,).
        - mu : array, shape (``n_components``, ``n_mixtures``)
            The location parameters (means) of the mixture components.
        - rho : array, shape (``n_components``, ``n_mixtures``)
            The shape parameters for the mixture components.
        - sbeta : array, shape (``n_components``, ``n_mixtures``)
            The scale (precision) parameters for the mixture components.
        - alpha : array, shape (``n_components``, ``n_mixtures``)
            The mixture weights for the mixture components.
        - c : array, shape (``n_components``,)
            The model bias terms.

    Raises
    ------
    NotImplementedError
        If ``do_reject`` is ``True``.

    Notes
    -----
    In Fortran AMICA, ``alpha``, ``sbeta``, ``mu``, and ``rho`` are of shape
    (``n_mixtures``, ``n_components``) (transposed compared to here).

    This function calls :func:`torch.set_default_dtype` with the configured
    dtype, which changes torch's process-wide default.

    References
    ----------
    .. footbibliography::

    """
    # Fail fast: reject unsupported options before any global state (log level,
    # torch default dtype) is modified.
    if do_reject:
        raise NotImplementedError(
            "Sample rejection by log likelihood is not yet supported."
        )

    # Local import: the module only imports ``numpy.testing`` at the top.
    import numpy as np

    # Honour the documented "array-like" contract (lists, DataFrames, ...).
    X = np.asarray(X)

    verbose = _validate_verbose(verbose)
    set_log_level("INFO" if verbose == 2 else "ERROR")

    n_samples, n_features = X.shape[0], X.shape[1]
    n_comps = n_components if n_components is not None else n_features

    if batch_size is None:
        batch_size = choose_batch_size(
            N=n_samples,
            n_comps=n_comps,
            n_mix=n_mixtures,
        )

    # Step 1: Create config and state objects (dataclass approach)
    config = AmicaConfig(
        n_features=n_features,
        n_components=n_comps,
        n_models=1,
        n_mixtures=n_mixtures,
        max_iter=max_iter,
        batch_size=batch_size,
        device=torch.device(device),
        pdftype=pdftype,
        tol=tol,
        lrate=lrate,
        rholrate=rholrate,
        do_newton=do_newton,
        newt_start=newt_start,
        newtrate=newtrate,
        newt_ramp=newt_ramp,
        do_reject=do_reject,
        verbose=verbose,
    )

    # Step 2: Create initial state
    torch.set_default_dtype(config.dtype)  # TODO: Make this less global
    state = get_initial_state(config)

    # Whitening works in place, so hand it a private copy of the input.
    dataseg = X.copy()
    dataseg, whitening_matrix, sldet, _whitening_inverse, mean = pre_whiten(
        X=dataseg,
        n_components=n_components,
        mineig=mineig,
        do_mean=bool(mean_center),
        do_sphere=whiten in {"zca", "pca"},
        do_approx_sphere=whiten == "zca",
        inplace=True,
    )

    # Run AMICA
    state_dict, LL = solve(
        X=dataseg,
        config=config,
        state=state,
        sldet=sldet,
        random_state=random_state,
        initial_weights=w_init,
        initial_scales=sbeta_init,
        initial_locations=mu_init,
    )

    return dict(
        S=whitening_matrix,
        mean=mean,
        gm=state_dict["gm"],
        mu=state_dict["mu"],
        rho=state_dict["rho"],
        sbeta=state_dict["sbeta"],
        W=state_dict["W"],
        A=state_dict["A"],
        c=state_dict["c"],
        alpha=state_dict["alpha"],
        LL=LL,
    )
|
|
297
|
+
|
|
298
|
+
def _init_tensor_or_random(value, shape, *, rng, dtype):
    """Return ``value`` as a ``dtype`` tensor of ``shape``, or U[0, 1) noise if ``None``.

    Parameters
    ----------
    value : array-like or None
        User-supplied initial values; anything accepted by ``torch.as_tensor``.
    shape : tuple of int
        Required shape of the result.
    rng : torch.Generator
        Generator used to draw the uniform samples when ``value`` is ``None``.
    dtype : torch.dtype
        Floating point dtype of the returned tensor.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``shape`` and dtype ``dtype``.

    Raises
    ------
    ValueError
        If ``value`` is given but its shape differs from ``shape``.
    """
    if value is None:
        return torch.rand(*shape, generator=rng, dtype=dtype)
    # Convert first so that lists / nested sequences get a ``.shape`` too.
    tensor = torch.as_tensor(value, dtype=dtype)
    if tuple(tensor.shape) != tuple(shape):
        raise ValueError(
            f"Expected initial values of shape {tuple(shape)}, "
            f"got {tuple(tensor.shape)}."
        )
    return tensor


def solve(
    X,
    *,
    config,
    state,
    sldet,
    random_state=None,
    initial_weights=None,
    initial_scales=None,
    initial_locations=None,
):
    """Initialize the model parameters and run the AMICA optimization.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_channels)
        Whitened data to be unmixed. Rows are samples (the optimizer batches
        along axis 0) and columns are the whitened channels. X has to be
        centered. It is converted to a tensor of ``config.dtype`` on
        ``config.device`` (no copy on CPU when dtypes already match).
    config : AmicaConfig
        Run configuration (sizes, dtype, device, learning rates, ...).
    state : AmicaState
        Model state. Its ``alpha``, ``mu``, ``sbeta``, ``c``, ``A`` and ``W``
        fields are (re)initialized here before optimization.
    sldet : scalar
        Log-determinant of the whitening matrix. ``.item()`` is called on it,
        so it is expected to be a numpy or torch scalar.
    random_state : int or None
        Seed for the ``torch.Generator`` used for the random initial values.
        NOTE(review): when ``None`` the generator is not reseeded and keeps
        torch's fixed default seed for a fresh ``torch.Generator()``, so the
        random initialization is presumably identical across calls — confirm
        whether that is intended.
    initial_weights : array-like, shape (n_components, n_components), optional
        Initial weights for the mixture components. If None, weights are
        initialized randomly. This is meant to be used for testing and
        debugging purposes only.
    initial_scales : array-like, shape (n_components, n_mixtures), optional
        Initial scales (sbeta) for the mixture components. If None, scales are
        initialized randomly. This is meant to be used for testing and
        debugging purposes only.
    initial_locations : array-like, shape (n_components, n_mixtures), optional
        Initial locations (mu) for the mixture components. If None, locations
        are initialized randomly. This is meant to be used for testing and
        debugging purposes only.

    Returns
    -------
    state_dict : dict
        The optimized state converted with ``AmicaState.to_numpy()``.
    LL : numpy.ndarray
        Log-likelihood history returned by the optimizer, moved to CPU.

    Raises
    ------
    ValueError
        If any of the ``initial_*`` arrays has the wrong shape.
    """
    # No-copy (if on CPU)
    X: DataTensor2D = torch.as_tensor(X, dtype=config.dtype, device=config.device)
    rng = torch.Generator()
    if random_state is not None:
        rng.manual_seed(random_state)
    # The API uses n_components, but under the hood we match the Fortran naming.
    # TODO: Maybe rename n_components to num_comps in the config dataclass?
    num_comps = config.n_components
    num_mix = config.n_mixtures
    dtype = config.dtype

    # !------------------- INITIALIZE VARIABLES ----------------------
    assert_allclose(state.gm.sum(), 1.0)

    # load_alpha: uniform mixture weights.
    state.alpha[:, :num_mix] = 1.0 / num_mix

    # load_mu: evenly spaced locations centered on zero, plus small jitter.
    # NOTE: the random draws below happen in a fixed order (locations, scales,
    # weights) so that a given random_state reproduces the same initialization.
    mu_values = torch.arange(num_mix) - (num_mix - 1) / 2
    state.mu[:, :] = mu_values[None, :]
    locations = _init_tensor_or_random(
        initial_locations, (num_comps, num_mix), rng=rng, dtype=dtype
    )
    state.mu = state.mu + 0.05 * (1.0 - 2.0 * locations)

    # load_beta: scales jittered around 1.
    scales = _init_tensor_or_random(
        initial_scales, (num_comps, num_mix), rng=rng, dtype=dtype
    )
    state.sbeta = 1.0 + 0.1 * (0.5 - scales)

    # load_c:
    state.c.fill_(0.0)

    # load_A: identity plus small off-diagonal noise, columns normalized to 1.
    weights = _init_tensor_or_random(
        initial_weights, (num_comps, num_comps), rng=rng, dtype=dtype
    )
    state.A[:, :] = 0.01 * (0.5 - weights)
    idx = torch.arange(num_comps)
    state.A[idx, idx] = 1.0
    Anrmk = torch.linalg.norm(state.A[:, :], dim=0)
    state.A[:, :] /= Anrmk
    # end load_A

    W, wc = get_unmixing_matrices(
        c=state.c,
        A=state.A,
        W=state.W,
    )
    assert W.dtype == torch.float64
    state.W = W.clone()
    del W  # safe guard against accidental use of W instead of state.W

    # !-------------------- Determine optimal block size -------------------
    log(f"1: block size = {config.batch_size}", level="info", color=None)

    # !XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX main loop XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
    log(
        "Solving. (please be patient, this may take a while)...",
        level="info",
        color="blue",
        weight="bold"
    )
    with torch.no_grad():
        state, LL = optimize(
            X=X,
            sldet=sldet.item(),
            wc=wc,
            config=config,
            state=state,
        )
    # Convert Tensors to numpy arrays for output
    state_dict = state.to_numpy()
    LL = LL.cpu().numpy()
    return state_dict, LL
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def optimize(
    *,
    X: DataTensor2D,
    sldet: float,
    wc: torch.Tensor,
    config: AmicaConfig,
    state: AmicaState,
):
    """Optimize the learnable parameters.

    Prepares everything the iteration loop needs and then delegates to
    ``_main_loop``:

    - iteration metrics seeded from the configured learning rates,
    - fresh accumulators from ``initialize_accumulators(config)``,
    - ``state`` and ``wc`` moved to ``config.device`` when it is not the CPU,
    - float64 scratch tensors (determinant terms, per-sample log-likelihood,
      log-likelihood history of length ``max(1, config.max_iter)``),
    - wall-clock timers and, when ``config.verbose == 1``, a progress bar.

    The progress bar, if created, is stopped even when the loop raises.

    Returns
    -------
    Whatever ``_main_loop`` returns; its caller unpacks this as
    ``(state, LL)``.
    """
    device = config.device
    f64 = torch.float64

    # A single user-facing tolerance drives every convergence criterion.
    tolerance = config.tol

    metrics = IterationMetrics(
        iter=1,
        lrate=config.lrate,
        rholrate=config.rholrate,
        lrate0=config.lrate,  # updates slower than lrate
        rholrate0=config.rholrate,  # updates slower than rholrate
        newtrate=config.newtrate,
    )

    accumulators = initialize_accumulators(config)
    if device.type != "cpu":
        state.to_device(device=device)
        wc = wc.to(device=device)

    # Work buffers allocated once and handed to the loop.
    Dsum = torch.tensor(0.0, dtype=f64, device=device)
    Dsign = torch.tensor(1.0, dtype=f64, device=device)
    per_sample_loglik = torch.zeros((X.shape[0],), dtype=f64, device=device)
    ll_history = torch.zeros(max(1, config.max_iter), dtype=f64, device=device)

    t_start = time.time()
    t_iter = time.time()

    progress, task_id = None, None
    if config.verbose == 1:
        progress, task_id = make_progress_bar(
            total=config.max_iter,
            lrate=metrics.lrate,
        )

    loop_kwargs = dict(
        X=X,
        sldet=sldet,
        wc=wc,
        config=config,
        state=state,
        do_newton=config.do_newton,
        leave=False,
        numdecs=0,  # consecutive iterations where LL decreased
        numincs=0,  # consecutive iterations where LL increased by less than tol
        metrics=metrics,
        accumulators=accumulators,
        Dsum=Dsum,
        Dsign=Dsign,
        loglik=per_sample_loglik,
        LL=ll_history,
        c_start=t_start,
        c1=t_iter,
        progress=progress,
        task_id=task_id,
        min_dll=tolerance,
        min_nd=tolerance,
    )
    try:
        return _main_loop(**loop_kwargs)
    finally:
        if progress is not None:
            progress.stop()
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def _main_loop(
|
|
492
|
+
*,
|
|
493
|
+
X: DataTensor2D,
|
|
494
|
+
sldet: float,
|
|
495
|
+
wc: torch.Tensor,
|
|
496
|
+
config: AmicaConfig,
|
|
497
|
+
state: AmicaState,
|
|
498
|
+
do_newton: bool,
|
|
499
|
+
leave: bool,
|
|
500
|
+
numdecs: int,
|
|
501
|
+
numincs: int,
|
|
502
|
+
metrics: IterationMetrics,
|
|
503
|
+
accumulators: AmicaAccumulators,
|
|
504
|
+
Dsum: torch.Tensor,
|
|
505
|
+
Dsign: torch.Tensor,
|
|
506
|
+
loglik: torch.Tensor,
|
|
507
|
+
LL: torch.Tensor,
|
|
508
|
+
c_start: float,
|
|
509
|
+
c1: float,
|
|
510
|
+
progress,
|
|
511
|
+
task_id,
|
|
512
|
+
min_dll: float,
|
|
513
|
+
min_nd: float,
|
|
514
|
+
):
|
|
515
|
+
"""Run the AMICA optimization loop and return updated state and LL history."""
|
|
516
|
+
while metrics.iter <= config.max_iter:
|
|
517
|
+
accumulators.reset()
|
|
518
|
+
loglik.fill_(0.0)
|
|
519
|
+
doing_newton = do_newton and (metrics.iter >= config.newt_start)
|
|
520
|
+
# !----- get determinants
|
|
521
|
+
# The Fortran code computed log|det(W)| indirectly via QR factorization
|
|
522
|
+
# We use slogdet on the original unmixing matrix to get sign and log|det|
|
|
523
|
+
_, Dsum = compute_sign_log_determinant(
|
|
524
|
+
unmixing_matrix=state.W,
|
|
525
|
+
minlog=minlog,
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
if config.do_reject:
|
|
529
|
+
raise NotImplementedError() # pragma: no cover
|
|
530
|
+
# !--------- loop over the blocks ----------
|
|
531
|
+
'''
|
|
532
|
+
# In Fortran, the OMP parallel region would start before the lines below.
|
|
533
|
+
# !$OMP PARALLEL DEFAULT(SHARED) &
|
|
534
|
+
# ...
|
|
535
|
+
# !print *, myrank+1, thrdnum+1, ': Inside openmp code ... '; call flush(6)
|
|
536
|
+
'''
|
|
537
|
+
|
|
538
|
+
# -- 0. Baseline terms for per-sample model log-likelihood --
|
|
539
|
+
initial = get_initial_model_log_likelihood(
|
|
540
|
+
unmixing_logdet=Dsum,
|
|
541
|
+
whitening_logdet=sldet,
|
|
542
|
+
model_weight=state.gm[0],
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
#=============================== Subsection ====================================
|
|
546
|
+
# === Begin chunk loop ===
|
|
547
|
+
# ==============================================================================
|
|
548
|
+
batch_loader = BatchLoader(X, axis=0, batch_size=config.batch_size)
|
|
549
|
+
for batch_idx, (data_batch, batch_indices) in enumerate(batch_loader):
|
|
550
|
+
|
|
551
|
+
# ======================================================================
|
|
552
|
+
# Expectation Step (E-step)
|
|
553
|
+
# ======================================================================
|
|
554
|
+
|
|
555
|
+
# 1. --- Compute source pre-activations
|
|
556
|
+
# !--- get b
|
|
557
|
+
if not state.W.device.type == data_batch.device.type:
|
|
558
|
+
raise ValueError(
|
|
559
|
+
f"Mismatch between state.W device ({state.W.device}) "
|
|
560
|
+
"and data_batch device ({data_batch.device})"
|
|
561
|
+
)
|
|
562
|
+
b = compute_preactivations(
|
|
563
|
+
X=data_batch,
|
|
564
|
+
unmixing_matrix=state.W,
|
|
565
|
+
bias=wc,
|
|
566
|
+
do_reject=config.do_reject,
|
|
567
|
+
n_weights=config.n_components,
|
|
568
|
+
)
|
|
569
|
+
# 2. --- Source densities, and per-sample mixture log-densities (logits)
|
|
570
|
+
y, z = compute_source_densities(
|
|
571
|
+
pdftype=config.pdftype,
|
|
572
|
+
b=b,
|
|
573
|
+
sbeta=state.sbeta,
|
|
574
|
+
mu=state.mu,
|
|
575
|
+
alpha=state.alpha,
|
|
576
|
+
rho=state.rho,
|
|
577
|
+
)
|
|
578
|
+
z0 = z # log densities (alias for clarity with Fortran code)
|
|
579
|
+
|
|
580
|
+
# 3. --- Aggregate mixture logits into per-sample model log likelihoods
|
|
581
|
+
modloglik = torch.full(
|
|
582
|
+
size=(data_batch.shape[0], 1),
|
|
583
|
+
fill_value=initial,
|
|
584
|
+
dtype=config.dtype,
|
|
585
|
+
device=config.device,
|
|
586
|
+
)
|
|
587
|
+
compute_model_loglikelihood_per_sample(
|
|
588
|
+
log_densities=z0,
|
|
589
|
+
out_modloglik=modloglik[:, 0],
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
# 4. -- Responsibilities within each component ---
|
|
593
|
+
# !--- get normalized z
|
|
594
|
+
z = compute_mixture_responsibilities(log_densities=z0, inplace=True)
|
|
595
|
+
z0 = None
|
|
596
|
+
del z0 # guard against use of stale name. z owns that memory
|
|
597
|
+
|
|
598
|
+
# 5. --- Across-model Responsibilities and Total Log-Likelihood ---
|
|
599
|
+
loglik[batch_indices] = compute_total_loglikelihood_per_sample(
|
|
600
|
+
modloglik=modloglik,
|
|
601
|
+
out_loglik=loglik[batch_indices]
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
if config.do_reject:
|
|
605
|
+
raise NotImplementedError() # pragma: no cover
|
|
606
|
+
else:
|
|
607
|
+
# 6. --- Responsibilities for each model ---
|
|
608
|
+
v = compute_model_responsibilities(
|
|
609
|
+
modloglik=modloglik,
|
|
610
|
+
out=modloglik, # reuse modloglik memory
|
|
611
|
+
)
|
|
612
|
+
modloglik = None
|
|
613
|
+
del modloglik # Guard. v owns that memory now
|
|
614
|
+
|
|
615
|
+
# ================================ M-STEP ==================================
|
|
616
|
+
# === Maximization-step: Parameter accumulators ===
|
|
617
|
+
# - Update parameters based on current responsibilities
|
|
618
|
+
# - Update unmixing matrices with gradient ascent and Newton-Raphson
|
|
619
|
+
# ==========================================================================
|
|
620
|
+
|
|
621
|
+
# !--- get g, u, ufp
|
|
622
|
+
#--------------------------FORTRAN CODE-------------------------
|
|
623
|
+
# vsum = sum( v(bstrt:bstp,h) )
|
|
624
|
+
# dgm_numer_tmp(h) = dgm_numer_tmp(h) + vsum
|
|
625
|
+
#---------------------------------------------------------------
|
|
626
|
+
model_resps = v[:, 0] # select responsibilities for this model
|
|
627
|
+
vsum = model_resps.sum()
|
|
628
|
+
|
|
629
|
+
# NOTE: u is a view of z, so changes to u will affect z (and vice versa)
|
|
630
|
+
u = compute_weighted_responsibilities(
|
|
631
|
+
mixture_responsibilities=z,
|
|
632
|
+
model_responsibilities=model_resps,
|
|
633
|
+
single_model=True,
|
|
634
|
+
)
|
|
635
|
+
z = None
|
|
636
|
+
del z # guard against use of stale name. u owns that memory now
|
|
637
|
+
usum = u.sum(dim=0) # shape: (nw, num_mix)
|
|
638
|
+
|
|
639
|
+
fp = compute_source_scores(
|
|
640
|
+
pdftype=config.pdftype,
|
|
641
|
+
y=y,
|
|
642
|
+
rho=state.rho,
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# For SGD, fp only exists to get ufp. Lets overwrite it to save memory.
|
|
646
|
+
ufp = precompute_weighted_scores(
|
|
647
|
+
weighted_responsibilities=u,
|
|
648
|
+
scores=fp,
|
|
649
|
+
out_ufp=fp if not doing_newton else None,
|
|
650
|
+
)
|
|
651
|
+
if not doing_newton:
|
|
652
|
+
fp = None
|
|
653
|
+
del fp # End of life. ufp owns that memory now
|
|
654
|
+
|
|
655
|
+
g = compute_scaled_scores(
|
|
656
|
+
weighted_scores=ufp,
|
|
657
|
+
scales=state.sbeta,
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
# --- Stochastic Gradient Descent accumulators ---
|
|
661
|
+
# gm (model weights)
|
|
662
|
+
accumulators.dgm_numer[0] += vsum
|
|
663
|
+
# c (bias)
|
|
664
|
+
accumulate_c_stats(
|
|
665
|
+
X=data_batch,
|
|
666
|
+
model_responsibilities=model_resps,
|
|
667
|
+
vsum=vsum,
|
|
668
|
+
n_weights=config.n_components,
|
|
669
|
+
out_numer=accumulators.dc_numer,
|
|
670
|
+
out_denom=accumulators.dc_denom,
|
|
671
|
+
)
|
|
672
|
+
# Alpha (mixture weights)
|
|
673
|
+
accumulate_alpha_stats(
|
|
674
|
+
usum=usum,
|
|
675
|
+
vsum=vsum,
|
|
676
|
+
out_numer=accumulators.dalpha_numer,
|
|
677
|
+
out_denom=accumulators.dalpha_denom,
|
|
678
|
+
)
|
|
679
|
+
# Mu (location)
|
|
680
|
+
accumulate_mu_stats(
|
|
681
|
+
ufp=ufp,
|
|
682
|
+
y=y,
|
|
683
|
+
sbeta=state.sbeta,
|
|
684
|
+
rho=state.rho,
|
|
685
|
+
out_numer=accumulators.dmu_numer,
|
|
686
|
+
out_denom=accumulators.dmu_denom,
|
|
687
|
+
)
|
|
688
|
+
# Beta (scale/precision)
|
|
689
|
+
accumulate_beta_stats(
|
|
690
|
+
usum=usum,
|
|
691
|
+
rho=state.rho,
|
|
692
|
+
ufp=ufp,
|
|
693
|
+
y=y,
|
|
694
|
+
out_numer=accumulators.dbeta_numer,
|
|
695
|
+
out_denom=accumulators.dbeta_denom,
|
|
696
|
+
)
|
|
697
|
+
# Rho (shape parameter of generalized Gaussian)
|
|
698
|
+
accumulate_rho_stats(
|
|
699
|
+
y=y,
|
|
700
|
+
rho=state.rho,
|
|
701
|
+
u=u,
|
|
702
|
+
usum=usum,
|
|
703
|
+
epsdble=epsdble,
|
|
704
|
+
out_numer=accumulators.drho_numer,
|
|
705
|
+
out_denom=accumulators.drho_denom,
|
|
706
|
+
)
|
|
707
|
+
# --- Newton-Raphson accumulators ---
|
|
708
|
+
if do_newton and metrics.iter >= config.newt_start:
|
|
709
|
+
# NOTE: Fortran computes dsigma_* for all iters, but its unnecessary
|
|
710
|
+
# Sigma^2 accumulators (noise variance)
|
|
711
|
+
accumulate_sigma2_stats(
|
|
712
|
+
model_responsibilities=model_resps,
|
|
713
|
+
source_estimates=b,
|
|
714
|
+
vsum=vsum,
|
|
715
|
+
out_numer=accumulators.newton.dsigma2_numer,
|
|
716
|
+
out_denom=accumulators.newton.dsigma2_denom,
|
|
717
|
+
)
|
|
718
|
+
# Kappa accumulators (curvature terms for A)
|
|
719
|
+
accumulate_kappa_stats(
|
|
720
|
+
ufp=ufp,
|
|
721
|
+
fp=fp,
|
|
722
|
+
sbeta=state.sbeta,
|
|
723
|
+
usum=usum,
|
|
724
|
+
out_numer=accumulators.newton.dkappa_numer,
|
|
725
|
+
out_denom=accumulators.newton.dkappa_denom,
|
|
726
|
+
)
|
|
727
|
+
# Lambda accumulators (nonlinearity shape parameter)
|
|
728
|
+
accumulate_lambda_stats(
|
|
729
|
+
fp=fp,
|
|
730
|
+
y=y,
|
|
731
|
+
u=u,
|
|
732
|
+
usum=usum,
|
|
733
|
+
out_numer=accumulators.newton.dlambda_numer,
|
|
734
|
+
out_denom=accumulators.newton.dlambda_denom,
|
|
735
|
+
)
|
|
736
|
+
# (dbar)Alpha accumulators
|
|
737
|
+
accumulators.newton.dbaralpha_numer[:, :] += usum
|
|
738
|
+
accumulators.newton.dbaralpha_denom[:, :] += vsum
|
|
739
|
+
# end if (do_newton and iteration >= newt_start)
|
|
740
|
+
|
|
741
|
+
# if (print_debug .and. (blk == 1) .and. (thrdnum == 0)) then
|
|
742
|
+
# if update_A:
|
|
743
|
+
#--------------------------FORTRAN CODE--------------------------------
|
|
744
|
+
# call DSCAL(nw*nw,dble(0.0),Wtmp2(:,:,thrdnum+1),1)
|
|
745
|
+
# call DGEMM('T','N',nw,nw,tblksize,dble(1.0),g(bstrt:bstp,:),...
|
|
746
|
+
# dble(1.0),Wtmp2(:,:,thrdnum+1),nw)
|
|
747
|
+
# call DAXPY(nw*nw,dble(1.0),Wtmp2(:,:,thrdnum+1),1,dWtmp(:,:,h),1)
|
|
748
|
+
#----------------------------------------------------------------------
|
|
749
|
+
accumulators.dA[:, :] += torch.matmul(g.T, b)
|
|
750
|
+
# end do (blk)'
|
|
751
|
+
|
|
752
|
+
# In Fortran, the OMP parallel region is closed here
|
|
753
|
+
# !$OMP END PARALLEL
|
|
754
|
+
|
|
755
|
+
# End of these lifetimes
|
|
756
|
+
del b, g, u, ufp, usum, vsum, v, model_resps, y
|
|
757
|
+
if doing_newton:
|
|
758
|
+
del fp # already deleted if not doing_newton
|
|
759
|
+
|
|
760
|
+
likelihood, ndtmpsum = accum_updates_and_likelihood(
|
|
761
|
+
X=X,
|
|
762
|
+
config=config,
|
|
763
|
+
accumulators=accumulators,
|
|
764
|
+
state=state,
|
|
765
|
+
total_LL=loglik.sum(),
|
|
766
|
+
iteration=metrics.iter
|
|
767
|
+
)
|
|
768
|
+
metrics.loglik = likelihood
|
|
769
|
+
metrics.ndtmpsum = ndtmpsum
|
|
770
|
+
# return accumulators, metrics
|
|
771
|
+
|
|
772
|
+
# ==============================================================================
|
|
773
|
+
ndtmpsum = metrics.ndtmpsum
|
|
774
|
+
LL[metrics.iter - 1] = metrics.loglik
|
|
775
|
+
|
|
776
|
+
# !----- display log likelihood of data
|
|
777
|
+
# if (seg_rank == 0) then
|
|
778
|
+
c2 = time.time()
|
|
779
|
+
t0 = c2 - c1
|
|
780
|
+
# if (mod(iter,outstep) == 0) then
|
|
781
|
+
|
|
782
|
+
if progress is not None and task_id is not None:
|
|
783
|
+
progress.update(
|
|
784
|
+
task_id,
|
|
785
|
+
completed=metrics.iter,
|
|
786
|
+
ll=f"{float(LL[metrics.iter - 1]):.4f}",
|
|
787
|
+
nd=f"{float(ndtmpsum):.4f}",
|
|
788
|
+
lrate=f"{metrics.lrate:.5f}",
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
if config.verbose == 2 and (metrics.iter % outstep) == 0:
|
|
792
|
+
report = (
|
|
793
|
+
f"Iteration {metrics.iter}, "
|
|
794
|
+
f"lrate = {metrics.lrate:.5f}, "
|
|
795
|
+
f"LL = {LL[metrics.iter - 1]:.7f}, "
|
|
796
|
+
f"nd = {ndtmpsum:.7f}, D = {float(Dsum):.5f} "
|
|
797
|
+
f"took {t0:.2f} seconds"
|
|
798
|
+
)
|
|
799
|
+
log(msg=report, level="info", color=None)
|
|
800
|
+
c1 = time.time()
|
|
801
|
+
|
|
802
|
+
# !----- check whether likelihood is increasing
|
|
803
|
+
# if (seg_rank == 0) then
|
|
804
|
+
# ! if we get a NaN early, try to reinitialize and startover a few times
|
|
805
|
+
if torch.isnan(LL[metrics.iter - 1]):
|
|
806
|
+
raise RuntimeError(f"Log Likelihood is NaN at iteration {metrics.iter}")
|
|
807
|
+
# end if
|
|
808
|
+
if metrics.iter > 1:
|
|
809
|
+
if (LL[metrics.iter - 1] < LL[metrics.iter - 2]):
|
|
810
|
+
# assert 1 == 0
|
|
811
|
+
log("Likelihood decreasing!", level="warning", color="yellow")
|
|
812
|
+
if (metrics.lrate < minlrate) or (ndtmpsum <= min_nd):
|
|
813
|
+
leave = True
|
|
814
|
+
log(
|
|
815
|
+
"minimum change threshold met, exiting loop",
|
|
816
|
+
level="info",
|
|
817
|
+
color="green",
|
|
818
|
+
weight="bold"
|
|
819
|
+
)
|
|
820
|
+
else:
|
|
821
|
+
metrics.lrate *= lratefact
|
|
822
|
+
metrics.rholrate *= rholratefact
|
|
823
|
+
numdecs += 1
|
|
824
|
+
if numdecs >= maxdecs:
|
|
825
|
+
metrics.lrate0 *= lratefact
|
|
826
|
+
if metrics.iter > config.newt_start:
|
|
827
|
+
metrics.rholrate0 *= rholratefact
|
|
828
|
+
if config.do_newton and metrics.iter > config.newt_start:
|
|
829
|
+
log(
|
|
830
|
+
"Reducing maximum Newton lrate",
|
|
831
|
+
level="info",
|
|
832
|
+
color="blue"
|
|
833
|
+
)
|
|
834
|
+
metrics.newtrate *= lratefact
|
|
835
|
+
numdecs = 0
|
|
836
|
+
# end if (numdecs >= maxdecs)
|
|
837
|
+
# end if (lrate vs minlrate)
|
|
838
|
+
# end if LL
|
|
839
|
+
if use_min_dll:
|
|
840
|
+
if (LL[metrics.iter - 1] - LL[metrics.iter - 2]) < min_dll:
|
|
841
|
+
numincs += 1
|
|
842
|
+
if numincs > maxincs:
|
|
843
|
+
leave = True
|
|
844
|
+
log(
|
|
845
|
+
"Exiting because likelihood increasing by less than "
|
|
846
|
+
f"{min_dll} for more than {maxincs} iterations ...",
|
|
847
|
+
level="info",
|
|
848
|
+
color="green",
|
|
849
|
+
weight="bold"
|
|
850
|
+
)
|
|
851
|
+
else:
|
|
852
|
+
numincs = 0
|
|
853
|
+
else:
|
|
854
|
+
raise NotImplementedError() # pragma: no cover
|
|
855
|
+
if use_grad_norm:
|
|
856
|
+
if ndtmpsum < min_nd:
|
|
857
|
+
leave = True
|
|
858
|
+
log(
|
|
859
|
+
"Exiting because norm of weight gradient less than "
|
|
860
|
+
f"{min_nd:.12f}",
|
|
861
|
+
level="info",
|
|
862
|
+
color="green",
|
|
863
|
+
weight="bold",
|
|
864
|
+
)
|
|
865
|
+
# end if (iter > 1)
|
|
866
|
+
if config.do_newton and (metrics.iter == config.newt_start):
|
|
867
|
+
log("Starting Newton ... setting numdecs to 0", level="info", color="blue")
|
|
868
|
+
numdecs = 0
|
|
869
|
+
# call MPI_BCAST(leave,1,MPI_LOGICAL,0,seg_comm,ierr)
|
|
870
|
+
# call MPI_BCAST(startover,1,MPI_LOGICAL,0,seg_comm,ierr)
|
|
871
|
+
if leave:
|
|
872
|
+
c_end = time.time()
|
|
873
|
+
log(f"Finished in {c_end - c_start:.2f} seconds", level="info")
|
|
874
|
+
return state, LL
|
|
875
|
+
# else:
|
|
876
|
+
# !----- do accumulators: gm, alpha, mu, sbeta, rho, W
|
|
877
|
+
# the updated lrate & rholrate for the next iteration
|
|
878
|
+
metrics.lrate, metrics.rholrate, state, wc = update_params(
|
|
879
|
+
X=X,
|
|
880
|
+
iteration=metrics.iter,
|
|
881
|
+
config=config,
|
|
882
|
+
state=state,
|
|
883
|
+
accumulators=accumulators,
|
|
884
|
+
lrate=metrics.lrate,
|
|
885
|
+
rholrate=metrics.rholrate,
|
|
886
|
+
lrate0=metrics.lrate0,
|
|
887
|
+
rholrate0=metrics.rholrate0,
|
|
888
|
+
wc=wc,
|
|
889
|
+
newtrate=metrics.newtrate,
|
|
890
|
+
)
|
|
891
|
+
|
|
892
|
+
# !----- reject data
|
|
893
|
+
if config.do_reject:
|
|
894
|
+
raise NotImplementedError() # pragma: no cover
|
|
895
|
+
|
|
896
|
+
metrics.iter += 1
|
|
897
|
+
# end if/else
|
|
898
|
+
# end while
|
|
899
|
+
log(
|
|
900
|
+
"Maximum number of iterations reached before convergence."
|
|
901
|
+
" Consider increasing max_iter or relaxing tol.",
|
|
902
|
+
level="warning",
|
|
903
|
+
color="yellow",
|
|
904
|
+
weight="bold",
|
|
905
|
+
)
|
|
906
|
+
c_end = time.time()
|
|
907
|
+
log(f"Finished in {c_end - c_start:.2f} seconds", level="info")
|
|
908
|
+
return state, LL
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
def accum_updates_and_likelihood(
    *,
    X,
    config,
    accumulators,
    state,
    total_LL,  # this is LLtmp in Fortran
    iteration,
):
    """Turn the accumulated statistics into the A-update and compute LL/ndtmpsum.

    Single-model port of the Fortran reduction step: the accumulated gradient
    ``dA`` is normalized, optionally preconditioned by the Newton
    approximation, mapped through ``A`` and averaged into ``dAK``.

    Parameters
    ----------
    X : torch.Tensor
        The data. Only ``X.shape[0]`` is read, as the number of samples used
        to normalize the log likelihood.
    config : object
        Reads ``n_components``, ``dtype``, ``device``, ``do_newton``,
        ``newt_start`` and ``do_reject``.
    accumulators : object
        Reads ``dgm_numer`` (plus whatever ``compute_newton_terms`` reads).
        ``dA`` is overwritten in place with ``A @ Wtmp``. ``dAK`` is updated in
        place (``+=`` then ``/=``); presumably the caller zeroes ``dAK`` before
        each iteration -- TODO confirm.
    state : object
        Reads ``A``, ``gm`` and, for the Newton step, ``mu``.
    total_LL : scalar tensor
        Log likelihood summed over samples (``LLtmp`` in Fortran).
    iteration : int
        Current iteration. The Newton step is used when ``config.do_newton``
        is set and ``iteration >= config.newt_start``.

    Returns
    -------
    likelihood
        ``total_LL / (X.shape[0] * n_components)``.
    ndtmpsum : torch.Tensor
        ``sqrt(sum(nd[comp_used]) / (n_components * count(comp_used)))``, where
        ``nd`` holds the squared column norms of ``dAK``.

    Raises
    ------
    RuntimeError
        If a Newton curvature product ``sk1 * sk2`` is not strictly positive
        (NaN included).
    NotImplementedError
        If ``config.do_reject`` is set.
    """
    nw = config.n_components
    use_newton = config.do_newton and iteration >= config.newt_start

    if use_newton:
        newton_terms = compute_newton_terms(
            accumulators=accumulators, config=config, mu=state.mu
        )
        sigma2 = newton_terms["sigma2"]
        kappa = newton_terms["kappa"]
        lambda_ = newton_terms["lambda_"]

    #--------------------------FORTRAN CODE-------------------------
    # call DSCAL(nw*nw,dble(-1.0)/dgm_numer(h),dA(:,:,h),1)
    # dA(i,i,h) = dA(i,i,h) + dble(1.0)
    #---------------------------------------------------------------
    if config.do_reject:
        raise NotImplementedError()  # pragma: no cover
    accumulators.dA[:, :] *= -1.0 / accumulators.dgm_numer[0]
    # diagonal() is a view, so this adds 1.0 to dA's diagonal in place.
    accumulators.dA.diagonal().add_(1.0)

    Wtmp_working = None
    if use_newton:
        dA = accumulators.dA
        # Assumes sigma2 and kappa are 1-D of length nw (the previous
        # meshgrid indexing required that too).
        # sk1[i, k] = sigma2(i) * kappa(k) and sk2[i, k] = sigma2(k) * kappa(i),
        # so sk2 is the transpose of sk1.
        sk1 = sigma2[:, None] * kappa[None, :]
        sk2 = sk1.T
        sk_prod = sk1 * sk2
        # Written as ~(x > 0) so that NaN products are rejected as well.
        if torch.any(~(sk_prod > 0.0)):
            raise RuntimeError(
                "Non-positive definite Hessian encountered in Newton update. "
                f"Iteration {iteration}. Try setting do_newton to False."
            )
        # Each off-diagonal pair (i, k) solves the 2x2 system
        # [[sk2, 1], [1, sk1]] @ (Wtmp(i,k), Wtmp(k,i)) = (dA(i,k), dA(k,i)).
        # For positive sk1, sk2 that matrix is positive definite only when
        # sk1*sk2 > 1. Otherwise the denominator (sk1*sk2 - 1) is <= 0, which
        # flips the step's sign or divides by zero.
        off_diag = torch.eye(nw, device=sk_prod.device) == 0
        posdef = bool(torch.all((sk_prod > 1.0) | ~off_diag))
        if posdef:
            Wtmp_working = torch.empty(
                (nw, nw), dtype=config.dtype, device=config.device
            )
            # Wtmp(i,k) = (sk1*dA(i,k,h) - dA(k,i,h)) / (sk1*sk2 - dble(1.0))
            Wtmp_working[:, :] = (sk1 * dA - dA.T) / (sk_prod - 1.0)
            # Wtmp(i,i) = dA(i,i,h) / lambda(i,h). This replaces the diagonal
            # of the off-diagonal formula above, which does not apply to i == k.
            Wtmp_working.diagonal().copy_(dA.diagonal() / lambda_)
        else:
            # Fall back to the natural-gradient direction (Wtmp = dA) instead
            # of taking a step with a non-positive denominator.
            # NOTE(review): believed to mirror the Fortran `posdef` fallback --
            # confirm. update_params still applies the newtrate lrate cap on
            # this iteration.
            log(
                f"Hessian not positive definite at iteration {iteration}; "
                "using natural gradient for this update.",
                level="warning",
                color="yellow",
            )

    if Wtmp_working is None:
        # Wtmp = dA(:,:,h). Cloned because dA is overwritten just below.
        Wtmp_working = accumulators.dA.clone()

    #--------------------------FORTRAN CODE-------------------------
    # call DSCAL(nw*nw,dble(0.0),dA(:,:,h),1)
    # call DGEMM('N','N',nw,nw,nw,dble(1.0),A(:,comp_list(:,h)),nw,Wtmp,nw,dble...
    #---------------------------------------------------------------
    accumulators.dA[:, :] = torch.matmul(state.A, Wtmp_working)

    #--------------------------FORTRAN CODE-------------------------
    # dAk(:,comp_list(i,h)) = dAk(:,comp_list(i,h)) + gm(h)*dA(:,i,h)
    # zeta(comp_list(i,h)) = zeta(comp_list(i,h)) + gm(h)
    # dAk(:,k) = dAk(:,k) / zeta(k)
    #---------------------------------------------------------------
    # Only model 0 is handled here, so every entry of zeta equals gm[0].
    zeta = torch.zeros(nw, dtype=config.dtype, device=config.device)
    zeta[:] += state.gm[0]
    accumulators.dAK[:, :] += state.gm[0] * accumulators.dA
    accumulators.dAK[:, :] /= zeta  # broadcasts across rows: per-column divide

    #--------------------------FORTRAN CODE-------------------------
    # nd(iter,:) = sum(dAk*dAk,1)
    # ndtmpsum = sqrt(sum(nd(iter,:),mask=comp_used) / (nw*count(comp_used)))
    #---------------------------------------------------------------
    # Fortran keeps nd for every iteration; only the current one is needed.
    nd = torch.sum(accumulators.dAK * accumulators.dAK, dim=0)
    assert nd.shape == (nw,)
    # comp_used is all True: identify_shared_comps is not implemented.
    comp_used = torch.ones(nw, dtype=torch.bool, device=nd.device)
    ndtmpsum = torch.sqrt(
        torch.sum(nd[comp_used]) / (nw * torch.count_nonzero(comp_used))
    )

    # LL(iter) = LLtmp2 / dble(all_blks*nw)
    # XXX: In the Fortran code LLtmp2 is the summed LLtmps across processes.
    likelihood = total_LL / (X.shape[0] * nw)
    return (likelihood, ndtmpsum)
|
|
1054
|
+
|
|
1055
|
+
|
|
1056
|
+
def update_params(
    *,
    X,
    iteration,
    config,
    state,
    accumulators,
    lrate,
    rholrate,
    lrate0,
    rholrate0,
    newtrate,
    wc,
):
    """Update the learnable ICA parameters and the learning rates.

    Applies the statistics gathered in ``accumulators`` to ``state`` (gm,
    alpha, c, A, mu, sbeta, rho). When ``doscaling`` is set, it also rescales
    the columns of A. Finally it recomputes the unmixing matrices.

    Parameters
    ----------
    X : torch.Tensor
        The data. Only ``X.shape[0]`` is read, as the sample count that
        normalizes ``gm``.
    iteration : int
        Current iteration. It selects the learning-rate cap (Newton vs.
        natural gradient) and, via ``share_start``/``share_iter``, whether A
        and the learning rates are updated at all.
    config : object
        Reads ``do_reject``, ``do_newton``, ``newt_start``, ``newt_ramp``,
        ``n_components`` and ``n_mixtures``.
    state : object
        Parameter container. Its tensors are updated in place, or rebound
        (``sbeta``, ``rho``, ``W``).
    accumulators : object
        Numerator/denominator statistics and the averaged A-update ``dAK``.
    lrate, rholrate : float
        Learning rates carried over from the previous iteration.
    lrate0 : float
        Cap on ``lrate`` when the Newton step is not active.
    rholrate0 : float
        Value ``rholrate`` is reset to whenever A is updated.
    newtrate : float
        Cap on ``lrate`` while the Newton step is active.
    wc
        Not read; replaced by the value returned from
        ``get_unmixing_matrices``.

    Returns
    -------
    lrate, rholrate
        Learning rates for the next iteration. They are unchanged when the A
        update is skipped.
    state
        The same ``state`` object, updated.
    wc
        As returned by ``get_unmixing_matrices``.

    Raises
    ------
    RuntimeError
        If alpha, c, mu, sbeta or rho become non-finite.
    NotImplementedError
        For ``do_reject``, ``share_comps``, or a zero-norm (or NaN-norm)
        column of A when ``doscaling`` is set.
    """

    def _require_finite(name, value):
        # Fail fast instead of letting NaN/inf leak into later iterations.
        if not torch.all(torch.isfinite(value)):
            raise RuntimeError(f"Non-finite {name} encountered during update.")

    # --- gm: gm = dgm_numer / dble(numgoodsum)
    if config.do_reject:
        raise NotImplementedError()  # pragma: no cover
    state.gm[:] = accumulators.dgm_numer / X.shape[0]

    # --- alpha, shape (num_comps, num_mix)
    state.alpha[:, :] = accumulators.dalpha_numer / accumulators.dalpha_denom
    _require_finite("alpha", state.alpha)

    # --- c
    state.c[:] = accumulators.dc_numer / accumulators.dc_denom
    _require_finite("c", state.c)

    # --- A and learning rates (gated by the share_start/share_iter schedule)
    if iteration < share_start or iteration % share_iter > 5:
        use_newton = config.do_newton and iteration >= config.newt_start
        # lrate = min( cap, lrate + min(dble(1.0)/dble(newt_ramp),lrate) ),
        # where cap is newtrate under Newton and lrate0 otherwise.
        lrate_cap = newtrate if use_newton else lrate0
        lrate = min(lrate_cap, lrate + min(1.0 / config.newt_ramp, lrate))
        rholrate = rholrate0
        # call DAXPY(nw*num_comps,dble(-1.0)*lrate,dAk,1,A,1)
        state.A -= lrate * accumulators.dAK

    # --- mu
    state.mu += accumulators.dmu_numer / accumulators.dmu_denom
    _require_finite("mu", state.mu)

    # --- sbeta, clamped to [invsigmin, invsigmax] (assumes invsigmin <= invsigmax).
    # Python-scalar bounds are applied in sbeta's own dtype/device. The old
    # torch.tensor(const) bounds were default-dtype (float32) CPU tensors.
    state.sbeta *= torch.sqrt(accumulators.dbeta_numer / accumulators.dbeta_denom)
    state.sbeta = torch.clamp(state.sbeta, min=invsigmin, max=invsigmax)
    _require_finite("sbeta", state.sbeta)

    # --- rho (shape parameter), clamped to [minrho, maxrho]
    # (assumes minrho <= maxrho).
    state.rho += rholrate * (
        1.0
        - (state.rho / torch.special.psi(1.0 + 1.0 / state.rho))
        * accumulators.drho_numer
        / accumulators.drho_denom
    )
    state.rho = torch.clamp(state.rho, min=minrho, max=maxrho)
    assert state.rho.shape == (config.n_components, config.n_mixtures)
    # After clamping only NaN can remain non-finite; check like the others.
    _require_finite("rho", state.rho)

    # --- rescale: normalize each column of A to unit L2 norm and compensate
    # in the matching rows of mu and sbeta.
    if doscaling:
        Anrmk = torch.linalg.norm(state.A, dim=0)
        if not torch.all(Anrmk > 0):  # also false for NaN norms
            raise NotImplementedError()  # pragma: no cover
        state.A /= Anrmk  # broadcasts across rows: per-column divide
        state.mu *= Anrmk[:, None]
        state.sbeta /= Anrmk[:, None]

    if share_comps:
        raise NotImplementedError()  # pragma: no cover

    state.W, wc = get_unmixing_matrices(
        c=state.c,
        A=state.A,
        W=state.W,
    )
    return lrate, rholrate, state, wc
|