amica-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
amica/core.py ADDED
@@ -0,0 +1,1165 @@
+ """Module containing the amica function entry point."""
+
+ import time
+
+ import torch
+ from numpy.testing import assert_allclose
+
+ from amica._types import (
+     DataTensor2D,
+ )
+ from amica.constants import (
+     doscaling,
+     epsdble,
+     invsigmax,
+     invsigmin,
+     lratefact,
+     maxdecs,
+     maxincs,
+     maxrho,
+     mineig,
+     minlog,
+     minlrate,
+     minrho,
+     outstep,
+     rholratefact,
+     share_comps,
+     share_iter,
+     share_start,
+     use_grad_norm,
+     use_min_dll,
+ )
+ from amica.kernels import (
+     accumulate_alpha_stats,
+     accumulate_beta_stats,
+     accumulate_c_stats,
+     accumulate_kappa_stats,
+     accumulate_lambda_stats,
+     accumulate_mu_stats,
+     accumulate_rho_stats,
+     accumulate_sigma2_stats,
+     compute_mixture_responsibilities,
+     compute_model_loglikelihood_per_sample,
+     compute_model_responsibilities,
+     compute_preactivations,
+     compute_scaled_scores,
+     compute_source_densities,
+     compute_source_scores,
+     compute_total_loglikelihood_per_sample,
+     compute_weighted_responsibilities,
+     precompute_weighted_scores,
+ )
+ from amica.linalg import (
+     compute_sign_log_determinant,
+     get_initial_model_log_likelihood,
+     get_unmixing_matrices,
+     pre_whiten,
+ )
+ from amica.state import (
+     AmicaAccumulators,
+     AmicaConfig,
+     AmicaState,
+     IterationMetrics,
+     get_initial_state,
+     initialize_accumulators,
+ )
+
+ from ._batching import BatchLoader, choose_batch_size
+ from ._newton import compute_newton_terms
+ from .utils._logging import log, set_log_level
+ from .utils._progress import make_progress_bar
+ from .utils._verbose import _validate_verbose
+
+
+ def fit_amica(
+     X,
+     *,
+     whiten="zca",
+     mean_center=True,
+     n_components=None,
+     device="cpu",
+     n_mixtures=3,
+     max_iter=500,
+     tol=1e-7,
+     lrate=0.05,
+     rholrate=0.05,
+     pdftype=0,
+     do_newton=True,
+     newt_start=50,
+     newtrate=1.0,
+     newt_ramp=10,
+     batch_size=None,
+     w_init=None,
+     sbeta_init=None,
+     mu_init=None,
+     do_reject=False,
+     random_state=None,
+     verbose=1,
+ ):
+     """Perform Adaptive Mixture Independent Component Analysis (AMICA).
+
+     Implements the AMICA algorithm as described in :footcite:t:`palmer2012` and
+     :footcite:t:`palmer2008`, and originally implemented in :footcite:t:`amica`.
+
+     Parameters
+     ----------
+     X : array-like, shape (``n_samples``, ``n_features``)
+         Training data, where ``n_samples`` is the number of samples and
+         ``n_features`` is the number of features.
+     n_components : int, optional
+         Number of components to extract. If ``None`` (default), set to
+         ``n_features``. Note that the number of components may be reduced during
+         whitening if the data are rank-deficient.
+     n_mixtures : int, optional
+         Number of mixture components to use in the Gaussian Mixture Model (GMM)
+         for each component's source density. Default is ``3``.
+     batch_size : int, optional
+         Batch size for processing data in chunks along the samples axis. If
+         ``None``, the batch size is chosen automatically to keep peak memory
+         under ~1.5 GB, and a warning is emitted if the chosen batch size falls
+         below ~8k samples. If the input data are small enough to process in one
+         shot, no batching is used. To enforce no batching, override this memory
+         cap by setting ``batch_size`` explicitly, e.g. to ``X.shape[0]`` to
+         process all samples at once, but note that this may lead to high memory
+         usage for large datasets.
+     device : str, optional
+         Device to run the computations on. Can be either ``'cpu'`` or ``'cuda'``
+         for GPU acceleration. Note that using ``'cuda'`` requires a compatible
+         NVIDIA GPU and the appropriate CUDA drivers installed.
+     whiten : str {"zca", "pca", "variance"}
+         Whitening method to apply to the data before fitting AMICA. Options are:
+
+         - "zca": Zero-phase component analysis (ZCA) whitening.
+         - "pca": Principal component analysis (PCA) whitening.
+         - "variance": Only variance normalization of the features (no sphering).
+     mean_center : bool, optional
+         If ``True``, X is mean corrected.
+     max_iter : int, optional
+         Maximum number of iterations to perform. Default is ``500``.
+     tol : float, optional
+         Convergence tolerance, used both as the minimum change in log likelihood
+         and as the minimum norm of the weight gradient. Default is ``1e-7``.
+     random_state : int or None, optional (default=None)
+         Used to perform a random initialization when ``w_init`` is not provided.
+         If int, ``random_state`` is the seed used by the random number generator
+         during whitening, and is used to set the seed during optimization
+         initialization.
+     w_init : array-like, shape (``n_components``, ``n_components``), optional
+         Initial weights for the mixture components. If ``None``, weights are
+         initialized randomly. This is meant to be used for testing and debugging
+         purposes only.
+     sbeta_init : array-like, shape (``n_components``, ``n_mixtures``), optional
+         Initial scales (sbeta) for the mixture components. If ``None``, scales
+         are initialized randomly. This is meant to be used for testing and
+         debugging purposes only.
+     mu_init : array-like, shape (``n_components``, ``n_mixtures``), optional
+         Initial locations (mu) for the mixture components. If ``None``, locations
+         are initialized randomly. This is meant to be used for testing and
+         debugging purposes only.
+     lrate : float, default=0.05
+         Initial learning rate for the natural gradient.
+     rholrate : float, default=0.05
+         Initial learning rate for the shape parameters.
+     pdftype : int, default=0
+         Type of source density model to use. Currently only ``0`` is supported,
+         which corresponds to the Gaussian Mixture Model (GMM) density.
+     do_newton : bool, default=True
+         If ``True``, the optimization method switches from Stochastic Gradient
+         Descent (SGD) to Newton updates after ``newt_start`` iterations. If
+         ``False``, only SGD updates are used.
+     newt_start : int, default=50
+         Number of iterations before switching to Newton updates if ``do_newton``
+         is ``True``.
+     newtrate : float, default=1.0
+         Learning rate for Newton iterations.
+     newt_ramp : int, default=10
+         Ramp for the learning rate: each iteration the rate grows by at most
+         ``1 / newt_ramp`` until it reaches its maximum.
+     do_reject : bool, default=False
+         If ``True``, reject samples by log likelihood. Not yet implemented; must
+         be ``False``.
+     verbose : int, default=1
+         Output mode during optimization:
+
+         - ``0``: silent
+         - ``1``: progress bar
+         - ``2``: per-iteration FORTRAN-style logs
+
+     Returns
+     -------
+     results : dict
+         Dictionary containing the following entries:
+
+         - mean : array, shape (``n_features``,) | ``None``
+           The mean over features. If ``mean_center=False``, this is ``None``.
+         - S : array, shape (``n_components``, ``n_features``)
+           The sphering (whitening) matrix applied to the data.
+         - W : array, shape (``n_components``, ``n_components``)
+           The unmixing matrix.
+         - A : array, shape (``n_components``, ``n_components``)
+           The mixing matrix in the space of sphered data. To get the mixing
+           matrix in the original data space, use ``np.linalg.pinv(S) @ A``.
+         - LL : array, shape (``max_iter``,)
+           The log-likelihood values at each iteration. If the algorithm converged
+           before reaching ``max_iter``, the remaining entries are zero.
+         - gm : array, shape (1,)
+           The Gaussian mixture model weights. Since only one model is supported,
+           this is of shape (1,).
+         - mu : array, shape (``n_components``, ``n_mixtures``)
+           The location parameters for the mixture components, i.e. the means of
+           the mixture components.
+         - rho : array, shape (``n_components``, ``n_mixtures``)
+           The shape parameters for the mixture components.
+         - sbeta : array, shape (``n_components``, ``n_mixtures``)
+           The scale (precision) parameters for the mixture components.
+         - alpha : array, shape (``n_components``, ``n_mixtures``)
+           The mixture weights for the mixture components.
+         - c : array, shape (``n_components``,)
+           The model bias terms.
+
+     Notes
+     -----
+     In Fortran AMICA, ``alpha``, ``sbeta``, ``mu``, and ``rho`` are of shape
+     (``n_mixtures``, ``n_components``) (transposed compared to here).
+
+     References
+     ----------
+     .. footbibliography::
+
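+     Examples
+     --------
+     A minimal usage sketch on synthetic data. The source-reconstruction line
+     assumes ``W`` acts on the sphered, mean-corrected data, which matches the
+     ``Returns`` description above::
+
+         import numpy as np
+         from amica.core import fit_amica
+
+         rng = np.random.default_rng(0)
+         X = rng.standard_normal((10_000, 8))  # (n_samples, n_features)
+         results = fit_amica(X, max_iter=100, random_state=0)
+         # Unmix: sphere the centered data, then apply the unmixing matrix
+         sources = results["W"] @ results["S"] @ (X - results["mean"]).T
+         # Mixing matrix in the original (unsphered) data space
+         A_orig = np.linalg.pinv(results["S"]) @ results["A"]
+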
+     """
+     verbose = _validate_verbose(verbose)
+     set_log_level("INFO" if verbose == 2 else "ERROR")
+
+     if batch_size is None:
+         batch_size = choose_batch_size(
+             N=X.shape[0],
+             n_comps=n_components if n_components is not None else X.shape[1],
+             n_mix=n_mixtures,
+         )
+     # Step 1: Create config and state objects (new dataclass approach)
+     config = AmicaConfig(
+         n_features=X.shape[1],  # number of channels
+         n_components=n_components if n_components is not None else X.shape[1],
+         n_models=1,
+         n_mixtures=n_mixtures,
+         max_iter=max_iter,
+         batch_size=batch_size,
+         device=torch.device(device),
+         pdftype=pdftype,
+         tol=tol,
+         lrate=lrate,
+         rholrate=rholrate,
+         do_newton=do_newton,
+         newt_start=newt_start,
+         newtrate=newtrate,
+         newt_ramp=newt_ramp,
+         do_reject=do_reject,
+         verbose=verbose,
+     )
+
+     # Step 2: Create initial state (this will eventually replace manual
+     # initialization)
+     torch.set_default_dtype(config.dtype)  # TODO: make this less global
+     state = get_initial_state(config)
+
+     # Init
+     if config.do_reject:
+         raise NotImplementedError(
+             "Sample rejection by log likelihood is not yet supported."
+         )  # pragma: no cover
+     dataseg = X.copy()
+
+     # Whitening
+     do_sphere = whiten in {"zca", "pca"}
+     do_approx_sphere = whiten == "zca"
+     do_mean = bool(mean_center)
+     dataseg, whitening_matrix, sldet, whitening_inverse, mean = pre_whiten(
+         X=dataseg,
+         n_components=n_components,
+         mineig=mineig,
+         do_mean=do_mean,
+         do_sphere=do_sphere,
+         do_approx_sphere=do_approx_sphere,
+         inplace=True,
+     )
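+     # NOTE: conceptually, dataseg is now the sphered data, roughly
+     # ``(X - mean) @ whitening_matrix.T`` (the exact orientation is handled by
+     # ``pre_whiten``), and sldet is the log|det| of the sphering transform,
+     # which later enters the per-sample log likelihood as a constant term.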
+
+     # Run AMICA
+     state_dict, LL = solve(
+         X=dataseg,
+         config=config,
+         state=state,
+         sldet=sldet,
+         random_state=random_state,
+         initial_weights=w_init,
+         initial_scales=sbeta_init,
+         initial_locations=mu_init,
+     )
+
+     return dict(
+         S=whitening_matrix,
+         mean=mean,
+         gm=state_dict["gm"],
+         mu=state_dict["mu"],
+         rho=state_dict["rho"],
+         sbeta=state_dict["sbeta"],
+         W=state_dict["W"],
+         A=state_dict["A"],
+         c=state_dict["c"],
+         alpha=state_dict["alpha"],
+         LL=LL,
+     )
+
+
+ def solve(
+     X,
+     *,
+     config,
+     state,
+     sldet,
+     random_state=None,
+     initial_weights=None,
+     initial_scales=None,
+     initial_locations=None,
+ ):
+     """Run the AMICA algorithm.
+
+     Parameters
+     ----------
+     X : array, shape (n_samples, n_components)
+         Matrix containing the whitened data to be unmixed, with samples along
+         the first axis. X has to be centered.
+     config : AmicaConfig
+         Configuration object for the fit.
+     state : AmicaState
+         Initial model state; optimized over the course of the fit.
+     sldet : torch.Tensor
+         Log |det| of the sphering (whitening) matrix.
+     random_state : int or None, optional
+         Seed for the random initialization of the parameters.
+     initial_weights : array-like, shape (n_components, n_components), optional
+         Initial weights for the mixture components. If None, weights are
+         initialized randomly. This is meant to be used for testing and debugging
+         purposes only.
+     initial_scales : array-like, shape (n_components, n_mixtures), optional
+         Initial scales (sbeta) for the mixture components. If None, scales are
+         initialized randomly. This is meant to be used for testing and debugging
+         purposes only.
+     initial_locations : array-like, shape (n_components, n_mixtures), optional
+         Initial locations (mu) for the mixture components. If None, locations
+         are initialized randomly. This is meant to be used for testing and
+         debugging purposes only.
+     """
+     # No-copy (if on CPU)
+     X: DataTensor2D = torch.as_tensor(X, dtype=config.dtype, device=config.device)
+     rng = torch.Generator()
+     if random_state is not None:
+         rng.manual_seed(random_state)
+     # The API will use n_components but under the hood we'll match the Fortran naming
+     # TODO: Maybe rename n_components to num_comps in the config dataclass?
+     num_comps = config.n_components
+     num_mix = config.n_mixtures
+     # !-------------------- ALLOCATE VARIABLES ---------------------
+
+     # !------------------- INITIALIZE VARIABLES ----------------------
+     # print *, myrank+1, ': Initializing variables ...'; call flush(6);
+     # if (seg_rank == 0) then
+
+     assert_allclose(state.gm.sum(), 1.0)
+     # load_alpha:
+     state.alpha[:, :num_mix] = 1.0 / num_mix
+     # load_mu:
+     mu_values = torch.arange(num_mix) - (num_mix - 1) / 2
+     state.mu[:, :] = mu_values[None, :]
+     if initial_locations is None:
+         initial_locations = torch.rand(num_comps, num_mix, generator=rng)
+     else:
+         assert initial_locations.shape == (num_comps, num_mix)
+         initial_locations = torch.as_tensor(initial_locations, dtype=torch.float64)
+     state.mu = state.mu + 0.05 * (1.0 - 2.0 * initial_locations)
+     # load_beta:
+     if initial_scales is None:
+         initial_scales = torch.rand(num_comps, num_mix, generator=rng)
+     else:
+         assert initial_scales.shape == (num_comps, num_mix)
+         initial_scales = torch.as_tensor(initial_scales, dtype=torch.float64)
+     state.sbeta = 1.0 + 0.1 * (0.5 - initial_scales)
+     # load_c:
+     state.c.fill_(0.0)
+
+     # load_A:
+     if initial_weights is None:
+         initial_weights = torch.rand(num_comps, num_comps, generator=rng)
+     else:
+         assert initial_weights.shape == (num_comps, num_comps)
+         initial_weights = torch.as_tensor(initial_weights, dtype=torch.float64)
+
+     state.A[:, :] = 0.01 * (0.5 - initial_weights)
+     idx = torch.arange(num_comps)
+     state.A[idx, idx] = 1.0
+     Anrmk = torch.linalg.norm(state.A[:, :], dim=0)
+     state.A[:, :] /= Anrmk
+     # end load_A
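+     # A starts as the identity plus a small random perturbation, with columns
+     # normalized to unit L2 norm; overall component scale is carried by
+     # sbeta/mu instead (see the doscaling rescale in update_params).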
+
+     W, wc = get_unmixing_matrices(
+         c=state.c,
+         A=state.A,
+         W=state.W,
+     )
+     assert W.dtype == torch.float64
+     state.W = W.clone()
+     del W  # safeguard against accidental use of W instead of state.W
+
+     # !-------------------- Determine optimal block size -------------------
+     log(f"1: block size = {config.batch_size}", level="info", color=None)
+
+     # !XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX main loop XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+     log(
+         "Solving. (please be patient, this may take a while)...",
+         level="info",
+         color="blue",
+         weight="bold",
+     )
+     with torch.no_grad():
+         state, LL = optimize(
+             X=X,
+             sldet=sldet.item(),
+             wc=wc,
+             config=config,
+             state=state,
+         )
+     # Convert tensors to numpy arrays for output
+     state_dict = state.to_numpy()
+     LL = LL.cpu().numpy()
+     return state_dict, LL
+
+
+ def optimize(
+     *,
+     X: DataTensor2D,
+     sldet: float,
+     wc: torch.Tensor,
+     config: AmicaConfig,
+     state: AmicaState,
+ ):
+     """Optimize the learnable parameters."""
+     # Set all convergence criteria to the user-specified tol
+     min_dll = config.tol
+     min_nd = config.tol
+
+     # These variables can be updated in the loop
+     leave = False
+     do_newton = config.do_newton
+     numdecs = 0  # number of consecutive iterations where LL decreased from previous
+     numincs = 0  # number of consecutive iterations where LL increased by less than tol
+     metrics = IterationMetrics(
+         iter=1,
+         lrate=config.lrate,
+         rholrate=config.rholrate,
+         lrate0=config.lrate,  # updates more slowly than lrate
+         rholrate0=config.rholrate,  # updates more slowly than rholrate
+         newtrate=config.newtrate,
+     )
+
+     # Initialize accumulators container
+     accumulators = initialize_accumulators(config)
+     if config.device.type != "cpu":
+         state.to_device(device=config.device)
+         wc = wc.to(device=config.device)
+     # We allocate these separately.
+     Dsum = torch.tensor(0.0, dtype=torch.float64, device=config.device)
+     Dsign = torch.tensor(1.0, dtype=torch.float64, device=config.device)
+     # per-sample log likelihood
+     loglik = torch.zeros((X.shape[0],), dtype=torch.float64, device=config.device)
+     # likelihood history
+     LL = torch.zeros(
+         max(1, config.max_iter), dtype=torch.float64, device=config.device
+     )
+
+     c_start = time.time()
+     c1 = time.time()
+     progress = None
+     task_id = None
+     if config.verbose == 1:
+         progress, task_id = make_progress_bar(
+             total=config.max_iter,
+             lrate=metrics.lrate,
+         )
+     try:
+         return _main_loop(
+             X=X,
+             sldet=sldet,
+             wc=wc,
+             config=config,
+             state=state,
+             do_newton=do_newton,
+             leave=leave,
+             numdecs=numdecs,
+             numincs=numincs,
+             metrics=metrics,
+             accumulators=accumulators,
+             Dsum=Dsum,
+             Dsign=Dsign,
+             loglik=loglik,
+             LL=LL,
+             c_start=c_start,
+             c1=c1,
+             progress=progress,
+             task_id=task_id,
+             min_dll=min_dll,
+             min_nd=min_nd,
+         )
+     finally:
+         if progress is not None:
+             progress.stop()
+
+
+ def _main_loop(
+     *,
+     X: DataTensor2D,
+     sldet: float,
+     wc: torch.Tensor,
+     config: AmicaConfig,
+     state: AmicaState,
+     do_newton: bool,
+     leave: bool,
+     numdecs: int,
+     numincs: int,
+     metrics: IterationMetrics,
+     accumulators: AmicaAccumulators,
+     Dsum: torch.Tensor,
+     Dsign: torch.Tensor,
+     loglik: torch.Tensor,
+     LL: torch.Tensor,
+     c_start: float,
+     c1: float,
+     progress,
+     task_id,
+     min_dll: float,
+     min_nd: float,
+ ):
+     """Run the AMICA optimization loop and return updated state and LL history."""
+     while metrics.iter <= config.max_iter:
+         accumulators.reset()
+         loglik.fill_(0.0)
+         doing_newton = do_newton and (metrics.iter >= config.newt_start)
+         # !----- get determinants
+         # The Fortran code computed log|det(W)| indirectly via QR factorization.
+         # We use slogdet on the original unmixing matrix to get sign and log|det|.
+         _, Dsum = compute_sign_log_determinant(
+             unmixing_matrix=state.W,
+             minlog=minlog,
+         )
+
+         if config.do_reject:
+             raise NotImplementedError()  # pragma: no cover
+         # !--------- loop over the blocks ----------
+         # In Fortran, the OMP parallel region would start before the lines below.
+         # !$OMP PARALLEL DEFAULT(SHARED) &
+         # ...
+         # !print *, myrank+1, thrdnum+1, ': Inside openmp code ... '; call flush(6)
+
+         # -- 0. Baseline terms for per-sample model log-likelihood --
+         initial = get_initial_model_log_likelihood(
+             unmixing_logdet=Dsum,
+             whitening_logdet=sldet,
+             model_weight=state.gm[0],
+         )
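+         # Per sample, this baseline collects the data-independent terms of the
+         # model log likelihood: log(gm[0]) + log|det(W)| + log|det(S)|.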
+
+         # ==========================================================================
+         # === Begin batch loop ===
+         # ==========================================================================
+         batch_loader = BatchLoader(X, axis=0, batch_size=config.batch_size)
+         for batch_idx, (data_batch, batch_indices) in enumerate(batch_loader):
+
+             # ======================================================================
+             # Expectation Step (E-step)
+             # ======================================================================
+
+             # 1. --- Compute source pre-activations
+             # !--- get b
+             if state.W.device.type != data_batch.device.type:
+                 raise ValueError(
+                     f"Mismatch between state.W device ({state.W.device}) "
+                     f"and data_batch device ({data_batch.device})"
+                 )
+             b = compute_preactivations(
+                 X=data_batch,
+                 unmixing_matrix=state.W,
+                 bias=wc,
+                 do_reject=config.do_reject,
+                 n_weights=config.n_components,
+             )
+             # 2. --- Source densities, and per-sample mixture log-densities (logits)
+             y, z = compute_source_densities(
+                 pdftype=config.pdftype,
+                 b=b,
+                 sbeta=state.sbeta,
+                 mu=state.mu,
+                 alpha=state.alpha,
+                 rho=state.rho,
+             )
+             z0 = z  # log densities (alias for clarity with Fortran code)
+
+             # 3. --- Aggregate mixture logits into per-sample model log likelihoods
+             modloglik = torch.full(
+                 size=(data_batch.shape[0], 1),
+                 fill_value=initial,
+                 dtype=config.dtype,
+                 device=config.device,
+             )
+             compute_model_loglikelihood_per_sample(
+                 log_densities=z0,
+                 out_modloglik=modloglik[:, 0],
+             )
+
+             # 4. --- Responsibilities within each component ---
+             # !--- get normalized z
+             z = compute_mixture_responsibilities(log_densities=z0, inplace=True)
+             z0 = None
+             del z0  # guard against use of stale name; z owns that memory
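+             # z now holds normalized mixture responsibilities, i.e. a softmax
+             # over the mixture log densities: z_j = exp(z0_j - logsumexp(z0)).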
+
+             # 5. --- Across-model Responsibilities and Total Log-Likelihood ---
+             loglik[batch_indices] = compute_total_loglikelihood_per_sample(
+                 modloglik=modloglik,
+                 out_loglik=loglik[batch_indices],
+             )
+
+             if config.do_reject:
+                 raise NotImplementedError()  # pragma: no cover
+             else:
+                 # 6. --- Responsibilities for each model ---
+                 v = compute_model_responsibilities(
+                     modloglik=modloglik,
+                     out=modloglik,  # reuse modloglik memory
+                 )
+                 modloglik = None
+                 del modloglik  # Guard. v owns that memory now
+
+             # ================================ M-STEP ==================================
+             # === Maximization step: parameter accumulators ===
+             # - Update parameters based on current responsibilities
+             # - Update unmixing matrices with gradient ascent and Newton-Raphson
+             # ==========================================================================
+
+             # !--- get g, u, ufp
+             # --------------------------FORTRAN CODE-------------------------
+             # vsum = sum( v(bstrt:bstp,h) )
+             # dgm_numer_tmp(h) = dgm_numer_tmp(h) + vsum
+             # ---------------------------------------------------------------
+             model_resps = v[:, 0]  # select responsibilities for this model
+             vsum = model_resps.sum()
+
+             # NOTE: u is a view of z, so changes to u will affect z (and vice versa)
+             u = compute_weighted_responsibilities(
+                 mixture_responsibilities=z,
+                 model_responsibilities=model_resps,
+                 single_model=True,
+             )
+             z = None
+             del z  # guard against use of stale name; u owns that memory now
+             usum = u.sum(dim=0)  # shape: (nw, num_mix)
+
+             fp = compute_source_scores(
+                 pdftype=config.pdftype,
+                 y=y,
+                 rho=state.rho,
+             )
+
+             # For SGD, fp only exists to get ufp. Let's overwrite it to save memory.
+             ufp = precompute_weighted_scores(
+                 weighted_responsibilities=u,
+                 scores=fp,
+                 out_ufp=fp if not doing_newton else None,
+             )
+             if not doing_newton:
+                 fp = None
+                 del fp  # end of life; ufp owns that memory now
+
+             g = compute_scaled_scores(
+                 weighted_scores=ufp,
+                 scales=state.sbeta,
+             )
+
+             # --- Stochastic Gradient Descent accumulators ---
+             # gm (model weights)
+             accumulators.dgm_numer[0] += vsum
+             # c (bias)
+             accumulate_c_stats(
+                 X=data_batch,
+                 model_responsibilities=model_resps,
+                 vsum=vsum,
+                 n_weights=config.n_components,
+                 out_numer=accumulators.dc_numer,
+                 out_denom=accumulators.dc_denom,
+             )
+             # Alpha (mixture weights)
+             accumulate_alpha_stats(
+                 usum=usum,
+                 vsum=vsum,
+                 out_numer=accumulators.dalpha_numer,
+                 out_denom=accumulators.dalpha_denom,
+             )
+             # Mu (location)
+             accumulate_mu_stats(
+                 ufp=ufp,
+                 y=y,
+                 sbeta=state.sbeta,
+                 rho=state.rho,
+                 out_numer=accumulators.dmu_numer,
+                 out_denom=accumulators.dmu_denom,
+             )
+             # Beta (scale/precision)
+             accumulate_beta_stats(
+                 usum=usum,
+                 rho=state.rho,
+                 ufp=ufp,
+                 y=y,
+                 out_numer=accumulators.dbeta_numer,
+                 out_denom=accumulators.dbeta_denom,
+             )
+             # Rho (shape parameter of the generalized Gaussian)
+             accumulate_rho_stats(
+                 y=y,
+                 rho=state.rho,
+                 u=u,
+                 usum=usum,
+                 epsdble=epsdble,
+                 out_numer=accumulators.drho_numer,
+                 out_denom=accumulators.drho_denom,
+             )
+             # --- Newton-Raphson accumulators ---
+             if do_newton and metrics.iter >= config.newt_start:
+                 # NOTE: Fortran computes dsigma_* for all iters, but it's unnecessary
+                 # Sigma^2 accumulators (noise variance)
+                 accumulate_sigma2_stats(
+                     model_responsibilities=model_resps,
+                     source_estimates=b,
+                     vsum=vsum,
+                     out_numer=accumulators.newton.dsigma2_numer,
+                     out_denom=accumulators.newton.dsigma2_denom,
+                 )
+                 # Kappa accumulators (curvature terms for A)
+                 accumulate_kappa_stats(
+                     ufp=ufp,
+                     fp=fp,
+                     sbeta=state.sbeta,
+                     usum=usum,
+                     out_numer=accumulators.newton.dkappa_numer,
+                     out_denom=accumulators.newton.dkappa_denom,
+                 )
+                 # Lambda accumulators (nonlinearity shape parameter)
+                 accumulate_lambda_stats(
+                     fp=fp,
+                     y=y,
+                     u=u,
+                     usum=usum,
+                     out_numer=accumulators.newton.dlambda_numer,
+                     out_denom=accumulators.newton.dlambda_denom,
+                 )
+                 # (dbar)Alpha accumulators
+                 accumulators.newton.dbaralpha_numer[:, :] += usum
+                 accumulators.newton.dbaralpha_denom[:, :] += vsum
+             # end if (do_newton and iteration >= newt_start)
+
+             # if (print_debug .and. (blk == 1) .and. (thrdnum == 0)) then
+             # if update_A:
+             # --------------------------FORTRAN CODE--------------------------------
+             # call DSCAL(nw*nw,dble(0.0),Wtmp2(:,:,thrdnum+1),1)
+             # call DGEMM('T','N',nw,nw,tblksize,dble(1.0),g(bstrt:bstp,:),...
+             #      dble(1.0),Wtmp2(:,:,thrdnum+1),nw)
+             # call DAXPY(nw*nw,dble(1.0),Wtmp2(:,:,thrdnum+1),1,dWtmp(:,:,h),1)
+             # ----------------------------------------------------------------------
+             accumulators.dA[:, :] += torch.matmul(g.T, b)
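+             # Accumulates the natural-gradient statistic sum_t g_t b_t^T over
+             # this batch (the DGEMM above); it is turned into the update
+             # direction for A in accum_updates_and_likelihood.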
+             # end do (blk)
+
+         # In Fortran, the OMP parallel region is closed here
+         # !$OMP END PARALLEL
+
+         # End of these lifetimes
+         del b, g, u, ufp, usum, vsum, v, model_resps, y
+         if doing_newton:
+             del fp  # already deleted if not doing_newton
+
+         likelihood, ndtmpsum = accum_updates_and_likelihood(
+             X=X,
+             config=config,
+             accumulators=accumulators,
+             state=state,
+             total_LL=loglik.sum(),
+             iteration=metrics.iter,
+         )
+         metrics.loglik = likelihood
+         metrics.ndtmpsum = ndtmpsum
+         # return accumulators, metrics
+
+         # ==============================================================================
+         ndtmpsum = metrics.ndtmpsum
+         LL[metrics.iter - 1] = metrics.loglik
+
+         # !----- display log likelihood of data
+         # if (seg_rank == 0) then
+         c2 = time.time()
+         t0 = c2 - c1
+         # if (mod(iter,outstep) == 0) then
+
+         if progress is not None and task_id is not None:
+             progress.update(
+                 task_id,
+                 completed=metrics.iter,
+                 ll=f"{float(LL[metrics.iter - 1]):.4f}",
+                 nd=f"{float(ndtmpsum):.4f}",
+                 lrate=f"{metrics.lrate:.5f}",
+             )
+
+         if config.verbose == 2 and (metrics.iter % outstep) == 0:
+             report = (
+                 f"Iteration {metrics.iter}, "
+                 f"lrate = {metrics.lrate:.5f}, "
+                 f"LL = {LL[metrics.iter - 1]:.7f}, "
+                 f"nd = {ndtmpsum:.7f}, D = {float(Dsum):.5f} "
+                 f"took {t0:.2f} seconds"
+             )
+             log(msg=report, level="info", color=None)
+         c1 = time.time()
+
+         # !----- check whether likelihood is increasing
+         # if (seg_rank == 0) then
+         # ! if we get a NaN early, try to reinitialize and start over a few times
+         if torch.isnan(LL[metrics.iter - 1]):
+             raise RuntimeError(f"Log Likelihood is NaN at iteration {metrics.iter}")
807
+ # end if
808
+ if metrics.iter > 1:
809
+ if (LL[metrics.iter - 1] < LL[metrics.iter - 2]):
810
+ # assert 1 == 0
811
+ log("Likelihood decreasing!", level="warning", color="yellow")
812
+ if (metrics.lrate < minlrate) or (ndtmpsum <= min_nd):
813
+ leave = True
814
+ log(
815
+ "minimum change threshold met, exiting loop",
816
+ level="info",
817
+ color="green",
818
+ weight="bold"
819
+ )
820
+ else:
821
+ metrics.lrate *= lratefact
822
+ metrics.rholrate *= rholratefact
823
+ numdecs += 1
824
+ if numdecs >= maxdecs:
825
+ metrics.lrate0 *= lratefact
826
+ if metrics.iter > config.newt_start:
827
+ metrics.rholrate0 *= rholratefact
828
+ if config.do_newton and metrics.iter > config.newt_start:
829
+ log(
830
+ "Reducing maximum Newton lrate",
831
+ level="info",
832
+ color="blue"
833
+ )
834
+ metrics.newtrate *= lratefact
835
+ numdecs = 0
836
+ # end if (numdecs >= maxdecs)
837
+ # end if (lrate vs minlrate)
838
+ # end if LL
839
+ if use_min_dll:
840
+ if (LL[metrics.iter - 1] - LL[metrics.iter - 2]) < min_dll:
841
+ numincs += 1
842
+ if numincs > maxincs:
843
+ leave = True
844
+ log(
845
+ "Exiting because likelihood increasing by less than "
846
+ f"{min_dll} for more than {maxincs} iterations ...",
847
+ level="info",
848
+ color="green",
849
+ weight="bold"
850
+ )
851
+ else:
852
+ numincs = 0
853
+ else:
854
+ raise NotImplementedError() # pragma: no cover
855
+ if use_grad_norm:
856
+ if ndtmpsum < min_nd:
857
+ leave = True
858
+ log(
859
+ "Exiting because norm of weight gradient less than "
860
+ f"{min_nd:.12f}",
861
+ level="info",
862
+ color="green",
863
+ weight="bold",
864
+ )
865
+ # end if (iter > 1)
866
+ if config.do_newton and (metrics.iter == config.newt_start):
867
+ log("Starting Newton ... setting numdecs to 0", level="info", color="blue")
868
+ numdecs = 0
869
+ # call MPI_BCAST(leave,1,MPI_LOGICAL,0,seg_comm,ierr)
870
+ # call MPI_BCAST(startover,1,MPI_LOGICAL,0,seg_comm,ierr)
+         if leave:
+             c_end = time.time()
+             log(f"Finished in {c_end - c_start:.2f} seconds", level="info")
+             return state, LL
+         # else:
+         # !----- do accumulators: gm, alpha, mu, sbeta, rho, W
+         # the updated lrate & rholrate for the next iteration
+         metrics.lrate, metrics.rholrate, state, wc = update_params(
+             X=X,
+             iteration=metrics.iter,
+             config=config,
+             state=state,
+             accumulators=accumulators,
+             lrate=metrics.lrate,
+             rholrate=metrics.rholrate,
+             lrate0=metrics.lrate0,
+             rholrate0=metrics.rholrate0,
+             wc=wc,
+             newtrate=metrics.newtrate,
+         )
+
+         # !----- reject data
+         if config.do_reject:
+             raise NotImplementedError()  # pragma: no cover
+
+         metrics.iter += 1
+         # end if/else
+     # end while
+     log(
+         "Maximum number of iterations reached before convergence."
+         " Consider increasing max_iter or relaxing tol.",
+         level="warning",
+         color="yellow",
+         weight="bold",
+     )
+     c_end = time.time()
+     log(f"Finished in {c_end - c_start:.2f} seconds", level="info")
+     return state, LL
+
+
+ def accum_updates_and_likelihood(
+     *,
+     X,
+     config,
+     accumulators,
+     state,
+     total_LL,  # this is LLtmp in Fortran
+     iteration,
+ ):
+     """Use the accumulated arrays to update the log likelihood and ndtmpsum."""
+     # !--- add to the cumulative dtmps
+     # ...
+     # --------------------------FORTRAN CODE-------------------------
+     # call MPI_REDUCE(dgm_numer_tmp,dgm_numer,num_models,MPI_DOUBLE_PRECISION,MPI_S...
+     # ...
+     # if update_A:
+     # call MPI_REDUCE(dWtmp,dA,nw*nw*num_models,MPI_DOUBLE_PRECISION,MPI_SUM,0,seg_co...
+     nw = config.n_components
+     Wtmp_working = torch.zeros(
+         (config.n_components, config.n_components),
+         dtype=config.dtype,
+         device=config.device,
+     )
+     # if (seg_rank == 0) then
+     if config.do_newton and iteration >= config.newt_start:
+         newton_terms = compute_newton_terms(
+             accumulators=accumulators, config=config, mu=state.mu
+         )
+
+         sigma2 = newton_terms["sigma2"]
+         kappa = newton_terms["kappa"]
+         lambda_ = newton_terms["lambda_"]
+         # if (print_debug) then
+     # end if (do_newton .and. iter >= newt_start)
+
+     # --------------------------FORTRAN CODE-------------------------
+     # if (print_debug) then
+     # print *, 'dA ', h, ' = '; call flush(6)
+     # call DSCAL(nw*nw,dble(-1.0)/dgm_numer(h),dA(:,:,h),1)
+     # dA(i,i,h) = dA(i,i,h) + dble(1.0)
+     # ---------------------------------------------------------------
+     if config.do_reject:
+         raise NotImplementedError()  # pragma: no cover
+     else:
+         accumulators.dA[:, :] *= -1.0 / accumulators.dgm_numer[0]
+
+     # Basically the same as np.fill_diagonal where the fill value is diag + 1.0
+     diag = accumulators.dA.diagonal()
+     idx = torch.arange(nw)
+     accumulators.dA[idx, idx] = diag + 1.0
+     # if (print_debug) then
961
+
962
+ if config.do_newton and iteration >= config.newt_start:
963
+ #--------------------------FORTRAN CODE-------------------------
964
+ # do i = 1,nw ... do k = 1,nw
965
+ # if (i == k) then
966
+ # Wtmp(i,i) = dA(i,i,h) / lambda(i,h)
967
+ # else
968
+ # sk1 = sigma2(i,h) * kappa(k,h)
969
+ # sk2 = sigma2(k,h) * kappa(i,h)
970
+ #---------------------------------------------------------------
971
+ # on-diagonal elements
972
+ diag = accumulators.dA.diagonal()
973
+ fill_values = diag / lambda_
974
+ idx = torch.arange(Wtmp_working.shape[0])
975
+ Wtmp_working[idx, idx] = fill_values
976
+
977
+ # off-diagonal elements
978
+ i_indices, k_indices = torch.meshgrid(
979
+ torch.arange(config.n_components, device=config.device),
980
+ torch.arange(config.n_components, device=config.device), indexing='ij',
981
+ )
982
+ off_diag_mask = i_indices != k_indices
983
+ sk1 = sigma2[i_indices] * kappa[k_indices]
984
+ sk2 = sigma2[k_indices] * kappa[i_indices]
985
+ positive_mask = (sk1 * sk2 > 0.0)
986
+ if torch.any(~positive_mask):
987
+ raise RuntimeError(
988
+ "Non-positive definite Hessian encountered in Newton update. "
989
+ f"Iteration {iteration}. Try setting do_newton to False."
990
+ )
991
+ condition_mask = positive_mask & off_diag_mask
992
+ if torch.any(condition_mask):
993
+ # # Wtmp(i,k) = (sk1*dA(i,k,h) - dA(k,i,h)) / (sk1*sk2 - dble(1.0))
994
+ numerator = (
995
+ sk1
996
+ * accumulators.dA[i_indices, k_indices]
997
+ - accumulators.dA[k_indices, i_indices]
998
+ )
999
+ denominator = sk1 * sk2 - 1.0
1000
+ Wtmp_working[condition_mask] = (numerator / denominator)[condition_mask]
1001
+ # end if (i == k)
1002
+ # end do (k)
1003
+ # end do (i)
1004
+ # end if (do_newton .and. iter >= newt_start)
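+     # Each off-diagonal pair (i, k)/(k, i) is solved as a closed-form 2x2 system
+     # with curvature terms sk1 = sigma2[i]*kappa[k] and sk2 = sigma2[k]*kappa[i],
+     # which is what the element-wise Fortran loop above expresses.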
+     if (not config.do_newton) or (iteration < config.newt_start):
+         # Wtmp = dA(:,:,h)
+         assert Wtmp_working.shape == accumulators.dA.shape == (nw, nw)
+         Wtmp_working = accumulators.dA.clone()
+         assert Wtmp_working.shape == (nw, nw)
+     # --------------------------FORTRAN CODE-------------------------
+     # call DSCAL(nw*nw,dble(0.0),dA(:,:,h),1)
+     # call DGEMM('N','N',nw,nw,nw,dble(1.0),A(:,comp_list(:,h)),nw,Wtmp,nw,dble...
+     # ---------------------------------------------------------------
+     accumulators.dA[:, :] = 0.0
+     accumulators.dA[:, :] += torch.matmul(state.A, Wtmp_working)
+
+     zeta = torch.zeros(config.n_components, dtype=config.dtype, device=config.device)
+     # --------------------------FORTRAN CODE-------------------------
+     # dAk(:,comp_list(i,h)) = dAk(:,comp_list(i,h)) + gm(h)*dA(:,i,h)
+     # zeta(comp_list(i,h)) = zeta(comp_list(i,h)) + gm(h)
+     # ---------------------------------------------------------------
+     source_columns = state.gm[0] * accumulators.dA
+     accumulators.dAK[:, :] += source_columns
+     zeta[:] += state.gm[0]
+
+     # --------------------------FORTRAN CODE-------------------------
+     # dAk(:,k) = dAk(:,k) / zeta(k)
+     # nd(iter,:) = sum(dAk*dAk,1)
+     # ndtmpsum = sqrt(sum(nd(iter,:),mask=comp_used) / (nw*count(comp_used)))
+     # ---------------------------------------------------------------
+     accumulators.dAK[:, :] /= zeta  # broadcasting division
+     # nd is (num_iters, num_comps) in Fortran, but we only store the current iteration
+     nd = torch.sum(accumulators.dAK * accumulators.dAK, dim=0)
+     assert nd.shape == (config.n_components,)
+
+     # comp_used should be a vector of True.
+     # In Fortran, comp_used was based on component availability, unless
+     # identify_shared_comps was run; there are no plans to implement that here.
+     comp_used = torch.ones(config.n_components, dtype=bool)
+     assert isinstance(comp_used, torch.Tensor)
+     assert comp_used.shape == (config.n_components,)
+     assert comp_used.dtype == torch.bool
+     ndtmpsum = torch.sqrt(torch.sum(nd) / (nw * torch.count_nonzero(comp_used)))
+     # end if (update_A)
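+     # ndtmpsum is the root-mean-square entry of the update direction dAK over
+     # the components in use; it is the statistic compared against min_nd by the
+     # use_grad_norm stopping rule in _main_loop.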
+
+     # if (seg_rank == 0) then
+     if config.do_reject:
+         raise NotImplementedError()  # pragma: no cover
+     else:
+         # LL(iter) = LLtmp2 / dble(all_blks*nw)
+         # XXX: In the Fortran code, LLtmp2 is the summed LLtmp across processes.
+         likelihood = total_LL / (X.shape[0] * nw)
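+         # i.e. the mean log likelihood per sample and per component, making the
+         # value comparable across data sizes and model dimensions.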
+     return (likelihood, ndtmpsum)
+
+
+ def update_params(
+     *,
+     X,
+     iteration,
+     config,
+     state,
+     accumulators,
+     lrate,
+     rholrate,
+     lrate0,
+     rholrate0,
+     newtrate,
+     wc,
+ ):
+     """Update the learnable ICA parameters and the learning rates."""
+     # if (seg_rank == 0) then
+     # if update_gm:
+     if config.do_reject:
+         raise NotImplementedError()  # pragma: no cover
+         # gm = dgm_numer / dble(numgoodsum)
+     else:
+         state.gm[:] = accumulators.dgm_numer / X.shape[0]
+     # end if (update_gm)
+
+     # if update_alpha:
+     # assert alpha.shape == (num_comps, num_mix)
+     state.alpha[:, :] = accumulators.dalpha_numer / accumulators.dalpha_denom
+     if torch.any(~torch.isfinite(state.alpha)):
+         raise RuntimeError("Non-finite alpha encountered during update.")
+
+     # if update_c:
+     # assert c.shape == (nw, num_models)
+     state.c[:] = accumulators.dc_numer / accumulators.dc_denom
+     if torch.any(~torch.isfinite(state.c)):
+         raise RuntimeError("Non-finite c encountered during update.")
+
+     # === Section: Apply parameter accumulators & rescale ===
+     # Apply accumulated statistics to update parameters, then rescale and refresh
+     # W/wc.
+     # !print *, 'updating A ...'; call flush(6)
+     if iteration < share_start or (iteration % share_iter > 5):
+         if config.do_newton and (iteration >= config.newt_start):
+             # lrate = min( newtrate, lrate + min(dble(1.0)/dble(newt_ramp),lrate) )
+             # rholrate = rholrate0
+             # call DAXPY(nw*num_comps,dble(-1.0)*lrate,dAk,1,A,1)
+             lrate = min(newtrate, lrate + min(1.0 / config.newt_ramp, lrate))
+             rholrate = rholrate0
+             state.A -= lrate * accumulators.dAK
+         else:
+             lrate = min(lrate0, lrate + min(1.0 / config.newt_ramp, lrate))
+             rholrate = rholrate0
+             # call DAXPY(nw*num_comps,dble(-1.0)*lrate,dAk,1,A,1)
+             state.A -= lrate * accumulators.dAK
+         # end if do_newton
+     # end if (update_A)
+
+     # if update_mu:
+     state.mu += accumulators.dmu_numer / accumulators.dmu_denom
+     if torch.any(~torch.isfinite(state.mu)):
+         raise RuntimeError("Non-finite mu encountered during update.")
+
+     # if update_beta:
+     state.sbeta *= torch.sqrt(accumulators.dbeta_numer / accumulators.dbeta_denom)
+     sbetatmp = torch.minimum(torch.tensor(invsigmax), state.sbeta)
+     state.sbeta = torch.maximum(torch.tensor(invsigmin), sbetatmp)
+     if torch.any(~torch.isfinite(state.sbeta)):
+         raise RuntimeError("Non-finite sbeta encountered during update.")
+
+     state.rho += rholrate * (
+         1.0
+         - (state.rho / torch.special.psi(1.0 + 1.0 / state.rho))
+         * accumulators.drho_numer
+         / accumulators.drho_denom
+     )
+     rhotmp = torch.minimum(torch.tensor(maxrho), state.rho)  # (num_comps, num_mix)
+     assert rhotmp.shape == (config.n_components, config.n_mixtures)
+     state.rho = torch.maximum(torch.tensor(minrho), rhotmp)
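+     # The shape update uses the digamma function psi(1 + 1/rho) from the
+     # generalized-Gaussian maximum-likelihood condition for rho; the result is
+     # then clamped to [minrho, maxrho] for numerical stability.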
+
+     # !--- rescale
+     # !print *, 'rescaling A ...'; call flush(6)
+     if doscaling:
+         # Compute the L2 norm of each column of A, then use it to normalize that
+         # column and rescale the corresponding rows of mu and sbeta, but only if
+         # the norm is positive.
+         Anrmk = torch.linalg.norm(state.A, dim=0)
+         positive_mask = Anrmk > 0
+         if positive_mask.all():
+             state.A[:, positive_mask] /= Anrmk[positive_mask]
+             state.mu[positive_mask, :] *= Anrmk[positive_mask, None]
+             state.sbeta[positive_mask, :] /= Anrmk[positive_mask, None]
+         else:
+             raise NotImplementedError()  # pragma: no cover
+     # end if (doscaling)
+
+     if share_comps:
+         raise NotImplementedError()  # pragma: no cover
+
+     state.W, wc = get_unmixing_matrices(
+         c=state.c,
+         A=state.A,
+         W=state.W,
+     )
+     # if (print_debug) then
+     # call MPI_BCAST(gm,num_models,MPI_DOUBLE_PRECISION,0,seg_comm,ierr)
+     # ...
+     return lrate, rholrate, state, wc