redbirdpy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redbirdpy/recon.py ADDED
@@ -0,0 +1,893 @@
1
"""
Redbird Reconstruction Module - Inverse problem solvers for DOT/NIRS.

INDEX CONVENTION: All mesh indices (elem, face, and the dual-mesh mapid)
stored in cfg/recon are 1-based to match MATLAB/iso2mesh. Conversion to
0-based occurs only when indexing numpy arrays, using local variables named
with '_0' suffix.

Functions:
    runrecon: Main reconstruction driver with iterative Gauss-Newton
    reginv: Regularized matrix inversion (auto-selects over/under-determined)
    reginvover: Overdetermined least-squares solver
    reginvunder: Underdetermined least-squares solver
    matreform: Reformat matrix equation for different output forms
    matflat: Flatten a dict of Jacobian blocks into a single 2D matrix
    prior: Generate structure-prior regularization matrices
    syncprop: Copy/interpolate properties from the recon structure to cfg
"""

__all__ = [
    "runrecon",
    "reginv",
    "reginvover",
    "reginvunder",
    "matreform",
    "matflat",
    "prior",
    "syncprop",
]
27
+
28
+ import numpy as np
29
+ from scipy import sparse
30
+ from scipy.sparse.linalg import spsolve
31
+ from typing import Dict, Tuple, Optional, Union, List, Any
32
+ import warnings
33
+
34
+ from .forward import runforward, jac, jacchrome
35
+ from .utility import sdmap, meshinterp
36
+ from .property import updateprop
37
+
38
+
39
def runrecon(
    cfg: dict,
    recon: dict,
    detphi0: Union[np.ndarray, dict],
    sd: Union[np.ndarray, dict] = None,
    **kwargs,
) -> tuple:
    """
    Perform iterative Gauss-Newton reconstruction.

    Parameters
    ----------
    cfg : dict
        Forward simulation structure (forward mesh). elem/face are 1-based.
        The Jacobian step reads cfg["deldotdel"], cfg["elem"] and cfg["evol"].
    recon : dict
        Reconstruction structure (updated in place) containing:
        - node, elem: Reconstruction mesh (optional, for dual-mesh). elem is 1-based.
        - param: Initial chromophore concentrations
        - prop: Initial optical properties (a 1-D prop is reshaped to one row)
        - lambda: Regularization parameter
        - seg: Segmentation labels (optional, label-based reconstruction)
        - mapid: Forward-node to recon-element mapping (1-based element
          indices into recon["elem"]; NaN for nodes outside the recon mesh)
        - mapweight: Barycentric weights for mapping (one row per forward node)
    detphi0 : ndarray or dict
        Measured data to fit
    sd : ndarray or dict, optional
        Source-detector mapping passed to runforward/jac; computed with
        ``sdmap(cfg)`` when omitted.
    **kwargs : dict
        Options: maxiter (default 5), lambda_ (default recon["lambda"] or
        0.05), tol (default 0, i.e. run all iterations), reform (default
        "real"), report (default True), prior, solverflag, rfcw (default [1]),
        lmat / ltl (explicit regularization matrices).

    Returns
    -------
    recon : dict
        Updated reconstruction with fitted properties
    resid : ndarray
        Residual (sum of absolute misfit) at each iteration
    cfg : dict
        Updated forward structure
    updates : list of dict
        Per-iteration update vectors keyed by Jacobian block
    Jmua : dict
        Jacobian(s) from the last iteration
    detphi : ndarray or dict
        Simulated measurements from the last forward run
    phi : ndarray or dict
        Forward fluence from the last forward run

    Raises
    ------
    ValueError
        If maxiter < 1 (the Jacobian/forward outputs would not exist).
    """
    import time

    # Parse options
    maxiter = kwargs.get("maxiter", 5)
    if maxiter < 1:
        raise ValueError(f"maxiter must be at least 1, got {maxiter}")
    lambda_ = kwargs.get("lambda_", recon.get("lambda", 0.05))
    report = kwargs.get("report", True)
    tol = kwargs.get("tol", 0)
    reform = kwargs.get("reform", "real")
    solverflag = kwargs.get("solverflag", {})
    rfcw = kwargs.get("rfcw", [1])
    prior_type = kwargs.get("prior", "")
    # Fewer rows/entries than this means "one row per tissue label"
    label_threshold = 50

    if isinstance(rfcw, int):
        rfcw = [rfcw]

    if sd is None:
        sd = sdmap(cfg)

    # Normalize recon["prop"] to always be 2D
    prop0 = recon.get("prop")
    if isinstance(prop0, np.ndarray):
        if prop0.ndim == 1:
            recon["prop"] = prop0.reshape(1, -1)
    elif isinstance(prop0, dict):
        for key, val in prop0.items():
            if isinstance(val, np.ndarray) and val.ndim == 1:
                prop0[key] = val.reshape(1, -1)

    # Label-based reconstruction (few prop rows) needs a segmentation;
    # default every forward element to label 1 when none is given.
    if (
        isinstance(recon.get("prop"), np.ndarray)
        and recon["prop"].shape[0] < label_threshold
        and "seg" not in recon
        and "elem" in cfg
    ):
        recon["seg"] = np.ones(cfg["elem"].shape[0], dtype=int)

    resid = np.zeros(maxiter)
    updates = []

    # Build regularization matrix if needed
    Aregu = {}
    if "lmat" in kwargs:
        Aregu["lmat"] = kwargs["lmat"]
    elif "ltl" in kwargs:
        Aregu["ltl"] = kwargs["ltl"]
    elif prior_type and "seg" in recon:
        lmat = prior(recon["seg"], prior_type, kwargs)
        # prior() returns None for unsupported type/shape combinations
        if lmat is not None:
            Aregu["lmat"] = lmat

    dual_mesh = "node" in recon and "elem" in recon and "mapid" in recon

    for iteration in range(maxiter):
        t_start = time.time()

        # Sync properties between recon and forward mesh
        if "param" in recon or "prop" in recon:
            cfg, recon = syncprop(cfg, recon)

        # Update cfg.prop from cfg.param if multi-spectral
        if "param" in cfg and isinstance(cfg.get("prop"), dict):
            cfg["prop"] = updateprop(cfg)

        detphi, phi = runforward(cfg, solverflag=solverflag, sd=sd, rfcw=rfcw)

        # Build Jacobians (one per wavelength; "mua" for single wavelength)
        wavelengths = [""]
        if isinstance(cfg.get("prop"), dict):
            wavelengths = list(cfg["prop"].keys())

        Jmua = {}
        for wv in wavelengths:
            sdwv = sd.get(wv, sd) if isinstance(sd, dict) else sd
            phiwv = phi.get(wv, phi) if isinstance(phi, dict) else phi
            Jmua_n, _ = jac(sdwv, phiwv, cfg["deldotdel"], cfg["elem"], cfg["evol"])
            Jmua["mua" if wv == "" else wv] = Jmua_n

        # Build chromophore Jacobians if multi-spectral
        if isinstance(cfg.get("prop"), dict) and "param" in cfg:
            chromophores = [
                k
                for k in cfg["param"].keys()
                if k in ["hbo", "hbr", "water", "lipids", "aa3"]
            ]
            if chromophores:
                Jmua = jacchrome(Jmua, chromophores)

        detphi0_flat = _flatten_detphi(detphi0, sd, wavelengths, rfcw)
        detphi_flat = _flatten_detphi(detphi, sd, wavelengths, rfcw)

        # Block structure: key -> (rows, cols) of each Jacobian block
        if isinstance(Jmua, dict):
            blocks = {k: v.shape for k, v in Jmua.items()}
        else:
            blocks = {"mua": Jmua.shape}

        Jflat = matflat(Jmua)

        # Reformat for real-valued solver if needed
        if reform != "complex":
            Jflat, misfit, _ = matreform(Jflat, detphi0_flat, detphi_flat, reform)
        else:
            misfit = detphi0_flat - detphi_flat

        # Map Jacobian to recon mesh if dual-mesh
        if dual_mesh:
            Jflat = _remap_jacobian(Jflat, recon, cfg)
            nn_recon = recon["node"].shape[0]
            blocks = {k: (v[0], nn_recon) for k, v in blocks.items()}

        # Compress to per-label columns only for a node-based seg (length
        # matches Jacobian columns) with few labels.
        if "seg" in recon and np.ndim(recon["seg"]) == 1:
            seg = recon["seg"]
            n_labels = len(np.unique(seg))
            if len(seg) == Jflat.shape[1] and n_labels < label_threshold:
                Jflat = _masksum(Jflat, seg)
                blocks = {k: (v[0], n_labels) for k, v in blocks.items()}

        resid[iteration] = np.sum(np.abs(misfit))

        # Prepare regularization once, based on system shape
        if iteration == 0 and "lmat" in Aregu and "ltl" not in Aregu:
            lmat = Aregu["lmat"]
            if Jflat.shape[0] >= Jflat.shape[1]:
                Aregu["ltl"] = lmat.T @ lmat
            else:
                from scipy.linalg import qr

                # scipy.linalg.qr requires a dense array
                if sparse.issparse(lmat):
                    lmat = lmat.toarray()
                _, rmat = qr(lmat)
                Aregu["lir"] = np.linalg.inv(np.triu(rmat))

        # Scale J to unit Frobenius norm for conditioning; undone on dmu.
        # np.abs keeps the scale real for complex Jacobians.
        jnorm = np.sqrt(np.sum(np.abs(Jflat) ** 2))
        blockscale = 1.0 / jnorm if jnorm > 0 else 1.0
        Jflat = Jflat * blockscale

        dmu = reginv(Jflat, misfit, lambda_, Aregu, blocks, **solverflag)
        dmu = dmu * blockscale

        # Split the update by block and apply to recon (or cfg params)
        update = {}
        idx = 0
        for key, shape in blocks.items():
            size = shape[1]
            dx = dmu[idx : idx + size]
            update[key] = dx
            idx += size
            _apply_update(recon, cfg, key, dx, label_threshold)

        updates.append(update)

        if report:
            elapsed = time.time() - t_start
            rel_resid = (
                resid[iteration] / resid[0] if iteration > 0 and resid[0] > 0 else 1.0
            )
            print(
                f"iter [{iteration + 1:4d}]: residual={resid[iteration]:.6e}, "
                f"relres={rel_resid:.6e} lambda={lambda_:.6e} (time={elapsed:.2f} s)"
            )

        # Stop when the relative residual change drops below tol
        if (
            iteration > 0
            and resid[0] > 0
            and abs(resid[iteration] - resid[iteration - 1]) / resid[0] < tol
        ):
            resid = resid[: iteration + 1]
            break

    recon["lambda"] = lambda_

    return recon, resid, cfg, updates, Jmua, detphi, phi


def _apply_update(
    recon: dict, cfg: dict, key: str, dx: np.ndarray, label_threshold: int = 50
) -> None:
    """Apply one block ``dx`` of the Gauss-Newton update in place.

    "mua"/"dcoeff" update columns 0/1 of recon["prop"]. Chromophore and
    scattering keys update recon["param"][key], or cfg["param"][key] when
    recon has no such entry. Other keys are ignored. Targets with fewer
    entries than len(dx) (and below label_threshold) are label-based: only
    their first entries receive the leading values of dx.
    """
    if key in ("mua", "dcoeff"):
        prop = recon.get("prop")
        if not isinstance(prop, np.ndarray):
            return
        col = 0 if key == "mua" else 1
        n_rows = prop.shape[0]
        if n_rows < len(dx) and n_rows < label_threshold:
            rows, delta = slice(0, n_rows), dx[:n_rows]
        else:
            rows, delta = slice(None), dx
        if key == "dcoeff":
            # Column 1 is converted via D = 1/(3*value); update in D-space.
            new_dcoeff = 1.0 / (3 * prop[rows, col]) + delta
            prop[rows, col] = 1.0 / (3 * new_dcoeff)
        else:
            prop[rows, col] += delta
        return

    if key in ("hbo", "hbr", "water", "lipids", "scatamp", "scatpow"):
        if "param" in recon and key in recon["param"]:
            target = recon["param"]
        elif "param" in cfg and key in cfg["param"]:
            target = cfg["param"]
        else:
            return

        param_val = target[key]
        n_param = len(param_val) if hasattr(param_val, "__len__") else 1

        if n_param < len(dx) and n_param < label_threshold:
            # Label-based: one value per tissue label
            if n_param == 1:
                if hasattr(param_val, "__len__"):
                    target[key][0] += dx[0]
                else:
                    target[key] += dx[0]
            else:
                for li in range(n_param):
                    target[key][li] += dx[li]
        else:
            # Node/element based: one value per node
            target[key] = param_val + dx
336
+
337
+
338
def reginv(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    Areg: dict = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Solve a regularized linear system, dispatching on the matrix shape.

    Tall or square systems (rows >= columns) go to ``reginvover`` with the
    ``"ltl"`` entry of *Areg*. Wide systems go to ``reginvunder`` with the
    ``"lir"`` entry. A missing entry is passed as None.
    """
    regu = {} if Areg is None else Areg
    n_rows, n_cols = Amat.shape[0], Amat.shape[1]

    if n_rows < n_cols:
        return reginvunder(Amat, rhs, lambda_, regu.get("lir"), blocks, **kwargs)
    return reginvover(Amat, rhs, lambda_, regu.get("ltl"), blocks, **kwargs)
361
+
362
+
363
def reginvover(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    LTL: np.ndarray = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Solve overdetermined Gauss-Newton normal equation.

    Solves: delta_mu = inv(J^H J + lambda*(L'L)) * J^H * (y - phi)

    J^H is the conjugate transpose, so complex systems (reform="complex")
    are solved as true least-squares problems. For real Amat this is the
    plain transpose.

    Parameters
    ----------
    Amat : ndarray
        Jacobian (Nmeas x Nunknown).
    rhs : ndarray
        Data misfit, one entry per row of Amat.
    lambda_ : float
        Regularization weight.
    LTL : ndarray or sparse matrix, optional
        L'L regularization matrix. If it is smaller than the Hessian it is
        added to consecutive diagonal blocks. None means identity (Tikhonov).
    blocks : dict, optional
        Not used here; accepted for interface parity with reginvunder.

    Returns
    -------
    ndarray
        Update of length Amat.shape[1]. Entries for zero-sensitivity columns
        are 0, and the dtype follows the solution (complex when Amat/rhs are).
    """
    length0 = Amat.shape[1]

    # Dense Hessian below; adding a sparse LTL to it would produce np.matrix
    if sparse.issparse(LTL):
        LTL = LTL.toarray()

    # Remove zero-sensitivity columns (the data cannot update them)
    idx0 = np.flatnonzero(np.sum(np.abs(Amat), axis=0))
    if len(idx0) < length0:
        Amat = Amat[:, idx0]
        if LTL is not None and LTL.shape[0] > len(idx0):
            Lidx = idx0[idx0 < LTL.shape[0]]
            LTL = LTL[np.ix_(Lidx, Lidx)]

    # Remove zero-data rows
    valid_rows = np.sum(np.abs(Amat), axis=1) != 0
    if not np.all(valid_rows):
        Amat = Amat[valid_rows, :]
        rhs = rhs[valid_rows]

    # Build normal equation with the conjugate transpose
    AmatH = Amat.conj().T
    rhs_proj = AmatH @ rhs.flatten()
    Hess = AmatH @ Amat

    # Add regularization
    if LTL is None:
        Hess[np.diag_indices_from(Hess)] += lambda_
    elif Hess.shape[0] == LTL.shape[0]:
        Hess = Hess + lambda_ * LTL
    else:
        # LTL describes one block; repeat it along the diagonal
        nx = LTL.shape[0]
        nh = Hess.shape[0]
        for i in range(0, nh, nx):
            end_i = min(i + nx, nh)
            Hess[i:end_i, i:end_i] = (
                Hess[i:end_i, i:end_i] + lambda_ * LTL[: end_i - i, : end_i - i]
            )

    # Normalize to unit diagonal for conditioning, then solve
    Hess_norm, Gdiag = _normalize_diag(Hess)

    if sparse.issparse(Hess_norm):
        res = Gdiag * spsolve(Hess_norm, Gdiag * rhs_proj)
    else:
        res = Gdiag * np.linalg.solve(Hess_norm, Gdiag * rhs_proj)

    # Restore full-length result (keep complex dtype if present)
    if len(idx0) < length0:
        res_full = np.zeros(length0, dtype=res.dtype)
        res_full[idx0] = res
        res = res_full

    return res
427
+
428
+
429
def reginvunder(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    invR: np.ndarray = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Solve underdetermined Gauss-Newton equation.

    Solves: delta_mu = inv(L'L)*J^H*inv(J*inv(L'L)*J^H + lambda*I)*(y-phi)

    It is evaluated in the dual space via the transform A' = J*invR, with
    J^H the conjugate transpose, so complex systems get the minimum-norm
    solution.

    Parameters
    ----------
    Amat : ndarray
        Jacobian (Nmeas x Nunknown).
    rhs : ndarray
        Data misfit, one entry per row of Amat.
    lambda_ : float
        Regularization weight added to the dual-space Hessian diagonal.
    invR : ndarray, optional
        Inverse of the R factor of L. It is applied to all columns when the
        sizes match; otherwise it is applied to each block in *blocks* whose
        column count equals invR.shape[0].
    blocks : dict, optional
        key -> (rows, cols) of each column block, in column order.

    Returns
    -------
    ndarray
        Update of length Amat.shape[1]. Entries for zero-sensitivity columns
        are 0, and the dtype follows the solution. The caller's Amat is not
        modified.
    """
    Alen = Amat.shape[1]

    # Remove zero columns
    idx = np.flatnonzero(np.sum(np.abs(Amat), axis=0))
    if len(idx) < Alen:
        Amat = Amat[:, idx]

    # Remove zero rows
    valid_rows = np.sum(np.abs(Amat), axis=1) != 0
    if not np.all(valid_rows):
        Amat = Amat[valid_rows, :]
        rhs = rhs[valid_rows]

    # Column ranges of each block, in dict order
    bounds = []
    if blocks is not None:
        cumlen = np.cumsum([0] + [shape[1] for shape in blocks.values()])
        bounds = list(zip(cumlen[:-1], cumlen[1:]))

    # Apply regularization transform
    nx = None
    if invR is not None:
        nx = invR.shape[0]
        if nx == Amat.shape[1]:
            Amat = Amat @ invR
        elif blocks is not None:
            # Copy first: blocks are overwritten in place and Amat may still
            # be the caller's array.
            Amat = Amat.astype(np.result_type(Amat.dtype, invR.dtype), copy=True)
            for lo, hi in bounds:
                if hi - lo == nx:
                    Amat[:, lo:hi] = Amat[:, lo:hi] @ invR

    rhs = rhs.flatten()

    # Build Hessian in dual space
    AmatH = Amat.conj().T
    Hess = Amat @ AmatH
    Hess[np.diag_indices_from(Hess)] += lambda_

    # Normalize and solve
    Hess_norm, Gdiag = _normalize_diag(Hess)

    if sparse.issparse(Hess_norm):
        y = Gdiag * spsolve(Hess_norm, Gdiag * rhs)
    else:
        y = Gdiag * np.linalg.solve(Hess_norm, Gdiag * rhs)

    # Transform back to primal space
    res = AmatH @ y
    if invR is not None:
        if nx == Amat.shape[1]:
            res = invR @ res
        else:
            for lo, hi in bounds:
                if hi - lo == nx:
                    res[lo:hi] = invR @ res[lo:hi]

    # Restore full length (keep complex dtype if present)
    if len(idx) < Alen:
        res_full = np.zeros(Alen, dtype=res.dtype)
        res_full[idx] = res
        res = res_full

    return res
510
+
511
+
512
def matreform(
    Amat: np.ndarray, ymeas: np.ndarray, ymodel: np.ndarray, form: str = "complex"
) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Reformat matrix equation for different output forms.

    Parameters
    ----------
    Amat : ndarray
        Jacobian, one row per measurement.
    ymeas, ymodel : ndarray
        Measured and modeled data (1-D for 'logphase').
    form : str
        'complex': No transformation
        'real': Real part only; [Re(A); Im(A)] rows when both A and the
                misfit are complex
        'reim': Real part only; the full 2x2 real block form
                [[Re, -Im], [Im, Re]] when both A and the misfit are complex
        'logphase': Log-amplitude (and, for complex ymodel, phase) form.
                The phase residual is wrapped to (-pi, pi].

    Returns
    -------
    newA : ndarray
        Reformatted matrix.
    newrhs : ndarray
        Reformatted right-hand side.
    nblock : int
        Number of stacked row blocks (1 or 2).

    Raises
    ------
    ValueError
        For an unknown *form*.
    """
    nblock = 1
    rhs = ymeas - ymodel

    if form == "complex":
        return Amat, rhs, nblock

    if form in ["real", "reim"]:
        newA = np.real(Amat)
        newrhs = np.real(rhs)

        if not np.isreal(rhs).all() and not np.isreal(Amat).all():
            if form == "reim":
                newA = np.block(
                    [[np.real(Amat), -np.imag(Amat)], [np.imag(Amat), np.real(Amat)]]
                )
            else:
                newA = np.vstack([np.real(Amat), np.imag(Amat)])
            newrhs = np.concatenate([np.real(rhs), np.imag(rhs)])
            nblock = 2

        return newA, newrhs, nblock

    if form == "logphase":
        # d log(y) = dy / y, and conj(y)/|y|^2 == 1/y
        temp = np.conj(ymodel) / np.abs(ymodel * ymodel)
        temp = temp[:, np.newaxis] * Amat if Amat.ndim == 2 else temp * Amat
        logamp = np.log(np.abs(ymeas)) - np.log(np.abs(ymodel))

        if np.isreal(ymodel).all():
            return np.real(temp), logamp, nblock

        # angle(ymeas * conj(ymodel)) is the phase difference wrapped to
        # (-pi, pi], avoiding spurious ~2*pi residuals across the branch cut.
        dphase = np.angle(ymeas * np.conj(ymodel))
        newA = np.vstack([np.real(temp), np.imag(temp)])
        newrhs = np.concatenate([logamp, dphase])
        return newA, newrhs, 2

    raise ValueError(f"Unknown form: {form}")
568
+
569
+
570
def matflat(Amat: Union[dict, np.ndarray], weight: np.ndarray = None) -> np.ndarray:
    """
    Flatten a dict of Jacobian blocks into one 2D matrix.

    A non-dict input is returned unchanged.

    - ``{key: matrix}``: matrices are concatenated side by side (columns),
      each multiplied by its entry in *weight*.
    - ``{key: {wavelength: matrix}}``: for every wavelength of the first key,
      the weighted blocks of all keys are concatenated side by side. The
      resulting per-wavelength rows are then stacked vertically.

    *weight* defaults to ones, one factor per top-level key.
    """
    if not isinstance(Amat, dict):
        return Amat

    names = list(Amat)
    scale = np.ones(len(names)) if weight is None else weight
    head = Amat[names[0]]

    if not isinstance(head, dict):
        return np.hstack([Amat[name] * scale[pos] for pos, name in enumerate(names)])

    stacked = []
    for wv in head:
        pieces = [Amat[name][wv] * scale[pos] for pos, name in enumerate(names)]
        stacked.append(np.hstack(pieces))
    return np.vstack(stacked)
594
+
595
+
596
def prior(seg: np.ndarray, priortype: str, params: dict = None) -> np.ndarray:
    """
    Generate structure-prior regularization matrix.

    Parameters
    ----------
    seg : ndarray
        A 1-D label vector (one label per node/element) for 'laplace' and
        'helmholtz'. For 'comp', a 2-D composition matrix with one row per
        node/element and one column per tissue class.
    priortype : str
        'laplace': 1 on the diagonal, -1/count between entries sharing a label
        'helmholtz': like 'laplace' but -1/(count + beta)
        'comp': sparse compositional prior for soft segmentation
    params : dict, optional
        'beta' (helmholtz/comp, default 1.0), 'alpha' (comp, default 0.1).

    Returns
    -------
    ndarray, scipy.sparse.csr_matrix or None
        Dense (n x n) matrix for the label priors, or a CSR matrix for
        'comp'. None when priortype is empty or does not fit the rank of seg.
    """
    if not priortype:
        return None

    params = params or {}
    seg = np.asarray(seg)

    if seg.ndim == 1:
        if priortype == "laplace":
            return _label_prior(seg, 0.0)
        if priortype == "helmholtz":
            return _label_prior(seg, params.get("beta", 1.0))
        return None

    if priortype == "comp" and seg.ndim == 2:
        return _comp_prior(seg, params.get("alpha", 0.1), params.get("beta", 1.0))

    return None


def _label_prior(seg: np.ndarray, shift: float) -> np.ndarray:
    """Dense prior: 1 on the diagonal, -1/(count+shift) between same-label entries."""
    _, inverse, counts = np.unique(seg, return_inverse=True, return_counts=True)
    Lmat = np.eye(len(seg))
    for i, cnt in enumerate(counts):
        if cnt > 1:
            idx = np.flatnonzero(inverse == i)
            Lmat[np.ix_(idx, idx)] = -1.0 / (cnt + shift)
    np.fill_diagonal(Lmat, 1.0)
    return Lmat


def _comp_prior(seg: np.ndarray, alpha: float, beta: float) -> sparse.csr_matrix:
    """Sparse compositional prior linking rows whose L1 composition distance < alpha*nc."""
    n, nc = seg.shape

    # Collect upper-triangle couplings one row at a time (vectorized over j)
    rows, cols, vals = [], [], []
    for i in range(n - 1):
        dval = np.sum(np.abs(seg[i + 1 :, :] - seg[i, :]), axis=1)
        hit = np.flatnonzero(dval < alpha * nc)
        if hit.size:
            rows.append(np.full(hit.size, i, dtype=np.intp))
            cols.append(hit + i + 1)
            vals.append(-alpha - dval[hit] / nc)

    if rows:
        r = np.concatenate(rows)
        c = np.concatenate(cols)
        v = np.concatenate(vals)
    else:
        r = c = np.zeros(0, dtype=np.intp)
        v = np.zeros(0)

    # Symmetric off-diagonal entries
    r_all = np.concatenate([r, c])
    c_all = np.concatenate([c, r])
    v_all = np.concatenate([v, v])

    # Normalize by the (pre-normalization) absolute off-diagonal row sums
    rowsum = np.abs(np.bincount(r_all, weights=v_all, minlength=n))
    v_all = v_all / (beta * np.sqrt(rowsum[r_all] * rowsum[c_all] + 1e-16))

    Lmat = sparse.csr_matrix((v_all, (r_all, c_all)), shape=(n, n))
    Lmat = Lmat + sparse.eye(n)
    return Lmat.tocsr()
667
+
668
+
669
def syncprop(cfg: dict, recon: dict) -> Tuple[dict, dict]:
    """
    Synchronize properties from the reconstruction structure into cfg.

    Handles both single-mesh and dual-mesh reconstruction scenarios:

    - A non-empty recon["param"] dict is copied or interpolated into
      cfg["param"]. Otherwise recon["prop"] (an array or a per-wavelength
      dict) is copied or interpolated into cfg["prop"].
    - Values with fewer than 50 rows/entries are label-based and copied
      directly.
    - Node-based values are interpolated onto the forward mesh with
      ``meshinterp`` when recon provides mapid, mapweight and elem (dual
      mesh). Otherwise they are copied.

    For dual-mesh reconstruction, mapid/mapweight map FORWARD mesh nodes to
    RECON mesh elements; mapid contains 1-based element indices into
    recon["elem"].

    ``meshinterp`` is imported only when an interpolation is needed. iso2mesh
    is preferred, with this package's utility module as fallback.

    Returns
    -------
    (cfg, recon) : tuple of dict
        cfg is updated in place and returned; recon is returned unmodified.
    """
    # Threshold to distinguish label-based from node/element-based values
    label_threshold = 50
    has_map = "mapid" in recon and "mapweight" in recon and "elem" in recon

    def _copy(val):
        return val.copy() if hasattr(val, "copy") else val

    def _interp(val):
        # Interpolate a recon-mesh quantity onto forward-mesh nodes; passing
        # None as toval lets meshinterp allocate the correctly sized output.
        try:
            from iso2mesh import meshinterp
        except ImportError:
            from .utility import meshinterp
        return meshinterp(
            val, recon["mapid"], recon["mapweight"], recon["elem"], None
        )

    params = recon.get("param")
    if params:
        first = next(iter(params.values()))
        param_len = len(first) if hasattr(first, "__len__") else 1

        if param_len < label_threshold:
            cfg["param"] = {k: _copy(v) for k, v in params.items()}
        else:
            cfg.setdefault("param", {})
            for key, val in params.items():
                cfg["param"][key] = _interp(val) if has_map else _copy(val)

    elif "prop" in recon:
        prop = recon["prop"]
        if isinstance(prop, dict):
            # Multi-wavelength
            first = next(iter(prop.values()), None)
            if first is not None and first.shape[0] >= label_threshold and has_map:
                cfg["prop"] = {k: _interp(v) for k, v in prop.items()}
            else:
                cfg["prop"] = {k: v.copy() for k, v in prop.items()}
        elif prop.shape[0] >= label_threshold and has_map:
            # Node-based prop on the recon mesh -> forward mesh nodes
            cfg["prop"] = _interp(prop)
        else:
            # Label-based, or same mesh - direct copy
            cfg["prop"] = prop.copy()

    return cfg, recon
782
+
783
+
784
+ def _normalize_diag(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
785
+ """Normalize matrix to have unit diagonal for better conditioning."""
786
+ Adiag = np.diag(A)
787
+ di = 1.0 / np.sqrt(np.abs(Adiag) + 1e-16)
788
+ Anorm = (di[:, np.newaxis] * di[np.newaxis, :]) * A
789
+ return Anorm, di
790
+
791
+
792
+ def _flatten_detphi(
793
+ detphi: Union[np.ndarray, dict],
794
+ sd: Union[np.ndarray, dict],
795
+ wavelengths: List[str],
796
+ rfcw: List[int],
797
+ ) -> np.ndarray:
798
+ """Flatten detector measurements from nested dict to 1D array."""
799
+ if isinstance(detphi, np.ndarray):
800
+ return detphi.flatten()
801
+
802
+ result = []
803
+ for wv in wavelengths:
804
+ if isinstance(detphi, dict):
805
+ phi_wv = detphi.get(wv, detphi)
806
+ else:
807
+ phi_wv = detphi
808
+
809
+ if isinstance(phi_wv, dict):
810
+ for md in rfcw:
811
+ result.extend(phi_wv.get(md, {}).get("detphi", phi_wv).flatten())
812
+ else:
813
+ result.extend(np.asarray(phi_wv).flatten())
814
+
815
+ return np.array(result)
816
+
817
+
818
+ def _remap_jacobian(J: np.ndarray, recon: dict, cfg: dict) -> np.ndarray:
819
+ """
820
+ Remap Jacobian from forward mesh nodes to reconstruction mesh nodes.
821
+
822
+ Parameters
823
+ ----------
824
+ J : ndarray
825
+ Jacobian on forward mesh (Nsd x Nn_forward)
826
+ recon : dict
827
+ Reconstruction structure with mapid (1-based), mapweight, elem (1-based)
828
+ cfg : dict
829
+ Forward structure
830
+
831
+ Returns
832
+ -------
833
+ J_new : ndarray
834
+ Jacobian on reconstruction mesh (Nsd x Nn_recon)
835
+ """
836
+ nn_recon = recon["node"].shape[0]
837
+ nn_forward = J.shape[1]
838
+ nsd = J.shape[0]
839
+
840
+ J_new = np.zeros((nsd, nn_recon), dtype=J.dtype)
841
+
842
+ mapid = recon["mapid"] # 1-based element indices into recon mesh
843
+ mapweight = recon["mapweight"] # Barycentric coordinates (Nn_forward x 4)
844
+
845
+ # Convert 1-based elem to 0-based for numpy indexing
846
+ elem_0 = recon["elem"][:, :4].astype(int) - 1
847
+ n_elem = elem_0.shape[0]
848
+
849
+ # For each forward mesh node, distribute its Jacobian contribution
850
+ # to the reconstruction mesh nodes of the enclosing element
851
+ for i in range(nn_forward):
852
+ eid_raw = mapid[i]
853
+
854
+ # Skip NaN entries (forward node outside recon mesh)
855
+ if np.isnan(eid_raw):
856
+ continue
857
+
858
+ eid = int(eid_raw) - 1 # Convert 1-based to 0-based
859
+
860
+ # Bounds check on element index
861
+ if eid < 0 or eid >= n_elem:
862
+ continue
863
+
864
+ # Get reconstruction mesh node indices for this element
865
+ node_ids = elem_0[eid, :] # 4 node indices (0-based)
866
+
867
+ # Bounds check on node indices
868
+ if np.any(node_ids < 0) or np.any(node_ids >= nn_recon):
869
+ continue
870
+
871
+ # Distribute Jacobian contribution using barycentric weights
872
+ for j in range(4):
873
+ node_idx = node_ids[j]
874
+ J_new[:, node_idx] += J[:, i] * mapweight[i, j]
875
+
876
+ return J_new
877
+
878
+
879
+ def _masksum(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
880
+ """
881
+ Sum columns by segmentation mask for label-based reconstruction.
882
+
883
+ Compresses node-based Jacobian to label-based by summing all nodes
884
+ with the same label.
885
+ """
886
+ labels = np.unique(mask)
887
+ result = np.zeros((data.shape[0], len(labels)), dtype=data.dtype)
888
+
889
+ for i, label in enumerate(labels):
890
+ idx = mask == label
891
+ result[:, i] = np.sum(data[:, idx], axis=1)
892
+
893
+ return result