redbirdpy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redbirdpy/__init__.py +112 -0
- redbirdpy/analytical.py +927 -0
- redbirdpy/forward.py +589 -0
- redbirdpy/property.py +602 -0
- redbirdpy/recon.py +893 -0
- redbirdpy/solver.py +814 -0
- redbirdpy/utility.py +1117 -0
- redbirdpy-0.1.0.dist-info/METADATA +596 -0
- redbirdpy-0.1.0.dist-info/RECORD +13 -0
- redbirdpy-0.1.0.dist-info/WHEEL +5 -0
- redbirdpy-0.1.0.dist-info/licenses/LICENSE.txt +674 -0
- redbirdpy-0.1.0.dist-info/top_level.txt +1 -0
- redbirdpy-0.1.0.dist-info/zip-safe +1 -0
redbirdpy/recon.py
ADDED
|
@@ -0,0 +1,893 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Redbird Reconstruction Module - Inverse problem solvers for DOT/NIRS.
|
|
3
|
+
|
|
4
|
+
INDEX CONVENTION: All mesh indices (elem, face) stored in cfg/recon are 1-based
|
|
5
|
+
to match MATLAB/iso2mesh. Conversion to 0-based occurs only when indexing numpy
|
|
6
|
+
arrays, using local variables named with '_0' suffix.
|
|
7
|
+
|
|
8
|
+
Functions:
|
|
9
|
+
runrecon: Main reconstruction driver with iterative Gauss-Newton
|
|
10
|
+
reginv: Regularized matrix inversion (auto-selects over/under-determined)
|
|
11
|
+
reginvover: Overdetermined least-squares solver
|
|
12
|
+
reginvunder: Underdetermined least-squares solver
|
|
13
|
+
matreform: Reformat matrix equation for different output forms
|
|
14
|
+
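prior: Generate structure-prior regularization matrices
matflat: Flatten a dict of Jacobian blocks into a single 2D matrix
syncprop: Synchronize properties between forward and reconstruction meshes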
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"runrecon",
|
|
19
|
+
"reginv",
|
|
20
|
+
"reginvover",
|
|
21
|
+
"reginvunder",
|
|
22
|
+
"matreform",
|
|
23
|
+
"matflat",
|
|
24
|
+
"prior",
|
|
25
|
+
"syncprop",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
from scipy import sparse
|
|
30
|
+
from scipy.sparse.linalg import spsolve
|
|
31
|
+
from typing import Dict, Tuple, Optional, Union, List, Any
|
|
32
|
+
import warnings
|
|
33
|
+
|
|
34
|
+
from .forward import runforward, jac, jacchrome
|
|
35
|
+
from .utility import sdmap, meshinterp
|
|
36
|
+
from .property import updateprop
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _apply_prop_update(
    recon: dict, key: str, dx: np.ndarray, label_threshold: int = 50
) -> None:
    """
    Add one solver increment to a column of recon["prop"], in place.

    ``key`` is "mua" (column 0) or "dcoeff" (column 1). For "dcoeff" the
    stored value v is mapped to 1/(3*v), incremented, and mapped back the
    same way. Nothing happens unless recon["prop"] is a numpy array.
    """
    prop = recon.get("prop")
    if not isinstance(prop, np.ndarray):
        return

    col = 0 if key == "mua" else 1
    nrow = prop.shape[0]

    if nrow < len(dx) and nrow < label_threshold:
        # Label-based: one row per label, only the leading entries of dx apply
        rows, step = slice(0, nrow), dx[:nrow]
    else:
        # Node/element based: dx is applied to the whole column
        rows, step = slice(None), dx

    if key == "dcoeff":
        prop[rows, col] = 1.0 / (3 * (1.0 / (3 * prop[rows, col]) + step))
    else:
        prop[rows, col] += step


def _apply_param_update(
    recon: dict, cfg: dict, key: str, dx: np.ndarray, label_threshold: int = 50
) -> None:
    """
    Add one solver increment to a named parameter.

    The parameter is looked up in recon["param"] first, then cfg["param"];
    if neither holds ``key`` the increment is silently ignored.
    """
    if "param" in recon and key in recon["param"]:
        target = recon
    elif "param" in cfg and key in cfg["param"]:
        target = cfg
    else:
        return

    value = target["param"][key]
    sized = hasattr(value, "__len__")
    count = len(value) if sized else 1

    if count < len(dx) and count < label_threshold:
        if sized:
            # Element-wise so that plain Python lists keep working
            for li in range(count):
                value[li] += dx[li]
        else:
            target["param"][key] = value + dx[0]
    else:
        # Node/element based: one value per unknown
        target["param"][key] = value + dx


def runrecon(
    cfg: dict,
    recon: dict,
    detphi0: Union[np.ndarray, dict],
    sd: Union[np.ndarray, dict] = None,
    **kwargs,
) -> tuple:
    """
    Perform iterative Gauss-Newton reconstruction.

    Parameters
    ----------
    cfg : dict
        Forward simulation structure (forward mesh). elem/face are 1-based.
    recon : dict
        Reconstruction structure containing:
        - node, elem: Reconstruction mesh (optional, for dual-mesh). elem is 1-based.
        - param: Initial chromophore concentrations
        - prop: Initial optical properties
        - lambda: Regularization parameter
        - bulk: Background property values
        - mapid: Forward-to-recon mesh mapping (1-based element indices;
          the Jacobian remapping subtracts 1 before indexing)
        - mapweight: Barycentric weights for mapping
        - seg: Segmentation labels (optional)
        ``recon`` (including recon["prop"]) is modified in place.
    detphi0 : ndarray or dict
        Measured data to fit
    sd : ndarray or dict
        Source-detector mapping; built with sdmap(cfg) when omitted
    **kwargs : dict
        Options: maxiter (5), lambda_ (recon["lambda"] or 0.05), tol (0),
        reform ("real"), report (True), prior (""), solverflag ({}),
        rfcw ([1]), lmat, ltl. The whole kwargs dict is also handed to
        prior() as its parameter dict.

    Returns
    -------
    recon : dict
        Updated reconstruction with fitted properties
    resid : ndarray
        Residual (sum of absolute misfit) at each completed iteration
    cfg : dict
        Updated forward structure
    updates : list of dict
        Per-iteration increments, keyed by Jacobian block name
    Jmua : dict or ndarray
        Jacobian of the last iteration (empty dict if no iteration ran)
    detphi, phi :
        Forward outputs of the last iteration (None if no iteration ran)
    """
    import time

    label_threshold = 50
    chromophore_keys = ("hbo", "hbr", "water", "lipids", "aa3")
    # Every chromophore that gets a Jacobian block must also be updatable
    param_keys = chromophore_keys + ("scatamp", "scatpow")

    # Parse options
    maxiter = kwargs.get("maxiter", 5)
    lambda_ = kwargs.get("lambda_", recon.get("lambda", 0.05))
    report = kwargs.get("report", True)
    tol = kwargs.get("tol", 0)
    reform = kwargs.get("reform", "real")
    solverflag = kwargs.get("solverflag", {})
    rfcw = kwargs.get("rfcw", [1])
    prior_type = kwargs.get("prior", "")

    if isinstance(rfcw, int):
        rfcw = [rfcw]

    if sd is None:
        sd = sdmap(cfg)

    # Normalize recon["prop"] to always be 2D
    if "prop" in recon:
        if isinstance(recon["prop"], np.ndarray):
            if recon["prop"].ndim == 1:
                recon["prop"] = recon["prop"].reshape(1, -1)
        elif isinstance(recon["prop"], dict):
            for key in recon["prop"]:
                if (
                    isinstance(recon["prop"][key], np.ndarray)
                    and recon["prop"][key].ndim == 1
                ):
                    recon["prop"][key] = recon["prop"][key].reshape(1, -1)

    # A prop table with few rows is treated as label-based; in that case make
    # sure a seg array exists, assuming all elements use label 1
    if (
        "prop" in recon
        and isinstance(recon["prop"], np.ndarray)
        and recon["prop"].shape[0] < label_threshold
        and "seg" not in recon
        and "elem" in cfg
    ):
        recon["seg"] = np.ones(cfg["elem"].shape[0], dtype=int)

    resid = np.zeros(maxiter)
    updates = []

    # Defined up front so that maxiter=0 still returns cleanly
    Jmua = {}
    detphi = None
    phi = None

    # Build regularization matrix if needed
    Aregu = {}
    if "lmat" in kwargs:
        Aregu["lmat"] = kwargs["lmat"]
    elif "ltl" in kwargs:
        Aregu["ltl"] = kwargs["ltl"]
    elif prior_type and "seg" in recon:
        lmat = prior(recon["seg"], prior_type, kwargs)
        # prior() returns None for unsupported type/seg combinations
        if lmat is not None:
            Aregu["lmat"] = lmat

    # Determine if using dual mesh
    dual_mesh = "node" in recon and "elem" in recon and "mapid" in recon

    # Main iteration loop
    for iteration in range(maxiter):
        t_start = time.time()

        # Sync properties between recon and forward mesh
        if "param" in recon or "prop" in recon:
            cfg, recon = syncprop(cfg, recon)

        # Update cfg.prop from cfg.param if multi-spectral
        if "param" in cfg and isinstance(cfg.get("prop"), dict):
            cfg["prop"] = updateprop(cfg)

        # Run forward simulation
        detphi, phi = runforward(cfg, solverflag=solverflag, sd=sd, rfcw=rfcw)

        # Build Jacobians
        wavelengths = [""]
        if isinstance(cfg.get("prop"), dict):
            wavelengths = list(cfg["prop"].keys())

        Jmua = {}

        for wv in wavelengths:
            sdwv = sd.get(wv, sd) if isinstance(sd, dict) else sd
            phiwv = phi.get(wv, phi) if isinstance(phi, dict) else phi

            Jmua_n, Jmua_e = jac(
                sdwv, phiwv, cfg["deldotdel"], cfg["elem"], cfg["evol"]
            )
            # Use "mua" as key for single-wavelength case
            key = "mua" if wv == "" else wv
            Jmua[key] = Jmua_n

        # Build chromophore Jacobians if multi-spectral
        if isinstance(cfg.get("prop"), dict) and "param" in cfg:
            chromophores = [k for k in cfg["param"].keys() if k in chromophore_keys]
            if chromophores:
                Jmua = jacchrome(Jmua, chromophores)

        # Flatten measurement data
        detphi0_flat = _flatten_detphi(detphi0, sd, wavelengths, rfcw)
        detphi_flat = _flatten_detphi(detphi, sd, wavelengths, rfcw)

        # Get block structure
        if isinstance(Jmua, dict):
            blocks = {k: v.shape for k, v in Jmua.items()}
        else:
            blocks = {"mua": Jmua.shape}

        # Flatten Jacobian
        Jflat = matflat(Jmua)

        # Reformat for real-valued solver if needed
        if reform != "complex":
            Jflat, misfit, _ = matreform(Jflat, detphi0_flat, detphi_flat, reform)
        else:
            misfit = detphi0_flat - detphi_flat

        # Map Jacobian to recon mesh if dual-mesh
        if dual_mesh:
            Jflat = _remap_jacobian(Jflat, recon, cfg)
            # Update blocks to reflect recon mesh size
            nn_recon = recon["node"].shape[0]
            blocks = {k: (v[0], nn_recon) for k, v in blocks.items()}

        # Compress for segmented reconstruction ONLY if seg is 1-D, its length
        # matches the Jacobian columns, and it has few unique labels
        if "seg" in recon and np.ndim(recon["seg"]) == 1:
            seg = recon["seg"]
            n_labels = len(np.unique(seg))

            if len(seg) == Jflat.shape[1] and n_labels < label_threshold:
                Jflat = _masksum(Jflat, seg)
                # Update blocks to reflect compressed size
                blocks = {k: (v[0], n_labels) for k, v in blocks.items()}

        # Store residual
        resid[iteration] = np.sum(np.abs(misfit))

        # Prepare regularization (once, on the first iteration)
        if iteration == 0 and "lmat" in Aregu and "ltl" not in Aregu:
            lmat = Aregu["lmat"]
            if Jflat.shape[0] >= Jflat.shape[1]:
                Aregu["ltl"] = lmat.T @ lmat
            else:
                from scipy.linalg import qr

                # scipy.linalg.qr needs a dense array
                if sparse.issparse(lmat):
                    lmat = lmat.toarray()
                _, rfactor = qr(lmat)
                Aregu["lir"] = np.linalg.inv(np.triu(rfactor))

        # Scale by the Frobenius norm; norm() is also correct for complex
        # Jacobians, and an all-zero Jacobian is left unscaled
        jnorm = np.linalg.norm(Jflat)
        blockscale = 1.0 / jnorm if jnorm > 0 else 1.0
        Jflat = Jflat * blockscale

        # Solve inverse problem
        dmu = reginv(Jflat, misfit, lambda_, Aregu, blocks, **solverflag)
        dmu = dmu * blockscale

        # Parse update and apply to recon structure
        update = {}
        idx = 0

        for key in list(blocks.keys()):
            size = blocks[key][1]
            dx = dmu[idx : idx + size]
            update[key] = dx
            idx += size

            if key in ("mua", "dcoeff"):
                _apply_prop_update(recon, key, dx, label_threshold)
            elif key in param_keys:
                _apply_param_update(recon, cfg, key, dx, label_threshold)

        updates.append(update)

        # Reference residual for relative measures; avoid dividing by zero
        resid_ref = resid[0] if resid[0] != 0 else 1.0

        if report:
            elapsed = time.time() - t_start
            rel_resid = resid[iteration] / resid_ref if iteration > 0 else 1.0
            print(
                f"iter [{iteration + 1:4d}]: residual={resid[iteration]:.6e}, "
                f"relres={rel_resid:.6e} lambda={lambda_:.6e} (time={elapsed:.2f} s)"
            )

        # Check convergence
        if (
            iteration > 0
            and abs(resid[iteration] - resid[iteration - 1]) / resid_ref < tol
        ):
            resid = resid[: iteration + 1]
            break

    recon["lambda"] = lambda_

    return recon, resid, cfg, updates, Jmua, detphi, phi
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def reginv(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    Areg: dict = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Dispatch a regularized linear solve based on the shape of ``Amat``.

    With at least as many rows as columns the system goes to reginvover,
    which receives Areg["ltl"] (or None); otherwise it goes to reginvunder,
    which receives Areg["lir"] (or None). ``blocks`` and any extra keyword
    arguments are forwarded unchanged.
    """
    reg = {} if Areg is None else Areg
    nrow, ncol = Amat.shape[0], Amat.shape[1]

    if nrow < ncol:
        return reginvunder(Amat, rhs, lambda_, reg.get("lir", None), blocks, **kwargs)

    return reginvover(Amat, rhs, lambda_, reg.get("ltl", None), blocks, **kwargs)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def reginvover(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    LTL: np.ndarray = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Solve overdetermined Gauss-Newton normal equation.

    Solves: delta_mu = inv(J'J + lambda*(L'L)) * J' * (y - phi)
    where ' is the conjugate transpose (identical to the plain transpose
    for real-valued systems).

    Parameters
    ----------
    Amat : ndarray
        Jacobian J. All-zero columns and rows are dropped before solving.
    rhs : ndarray
        Misfit vector (y - phi), one entry per row of ``Amat``.
    lambda_ : float
        Regularization weight. Added to the Hessian diagonal when ``LTL``
        is None, otherwise multiplies ``LTL``.
    LTL : ndarray or sparse matrix, optional
        Regularization matrix L'L. If smaller than the Hessian it is added
        block-by-block along the diagonal.
    blocks, **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    ndarray
        Solution of length ``Amat.shape[1]``; entries belonging to dropped
        zero columns are 0. The dtype follows the solve, so complex
        solutions keep their imaginary part.
    """
    ncol_full = Amat.shape[1]

    # A sparse LTL added to a dense Hessian would yield np.matrix; densify
    if LTL is not None and sparse.issparse(LTL):
        LTL = LTL.toarray()

    # Remove zero-sensitivity columns
    keep_cols = np.flatnonzero(np.sum(np.abs(Amat), axis=0) != 0)
    trimmed = len(keep_cols) < ncol_full

    if trimmed:
        Amat = Amat[:, keep_cols]
        if LTL is not None and LTL.shape[0] > len(keep_cols):
            lidx = keep_cols[keep_cols < LTL.shape[0]]
            LTL = LTL[np.ix_(lidx, lidx)]

    # Remove zero-data rows
    keep_rows = np.sum(np.abs(Amat), axis=1) != 0
    if np.sum(keep_rows) < Amat.shape[0]:
        Amat = Amat[keep_rows, :]
        rhs = rhs[keep_rows]

    # Build normal equation with the Hermitian transpose
    Aherm = Amat.conj().T
    rhs_proj = Aherm @ rhs.flatten()
    Hess = Aherm @ Amat

    # Add regularization
    if LTL is None:
        Hess[np.diag_indices_from(Hess)] += lambda_
    elif Hess.shape[0] == LTL.shape[0]:
        Hess = Hess + lambda_ * LTL
    else:
        # LTL covers one block; repeat it along the Hessian diagonal
        nx = LTL.shape[0]
        for i in range(0, Hess.shape[0], nx):
            end_i = min(i + nx, Hess.shape[0])
            Hess[i:end_i, i:end_i] = (
                Hess[i:end_i, i:end_i] + lambda_ * LTL[: end_i - i, : end_i - i]
            )

    # Normalize and solve
    Hess_norm, Gdiag = _normalize_diag(Hess)

    if sparse.issparse(Hess_norm):
        res = Gdiag * spsolve(Hess_norm, Gdiag * rhs_proj)
    else:
        res = Gdiag * np.linalg.solve(Hess_norm, Gdiag * rhs_proj)

    # Restore full-length result, keeping the dtype of the solution
    if trimmed:
        res_full = np.zeros(ncol_full, dtype=res.dtype)
        res_full[keep_cols] = res
        res = res_full

    return res
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def reginvunder(
    Amat: np.ndarray,
    rhs: np.ndarray,
    lambda_: float,
    invR: np.ndarray = None,
    blocks: dict = None,
    **kwargs,
) -> np.ndarray:
    """
    Solve underdetermined Gauss-Newton equation.

    Solves: delta_mu = inv(L'L)*J'*inv(J*inv(L'L)*J' + lambda*I)*(y-phi)
    where ' is the conjugate transpose (identical to the plain transpose
    for real-valued systems).

    Parameters
    ----------
    Amat : ndarray
        Jacobian J. All-zero columns and rows are dropped before solving.
        The caller's array is never modified.
    rhs : ndarray
        Misfit vector (y - phi), one entry per row of ``Amat``.
    lambda_ : float
        Value added to the diagonal of the dual-space Hessian.
    invR : ndarray, optional
        Matrix right-multiplied onto J before the solve and left-multiplied
        onto the solution afterwards. Applied to the whole matrix when its
        size equals the column count, otherwise to every block of
        ``blocks`` whose width equals its size.
    blocks : dict, optional
        Maps block name to a shape tuple; entry [1] is the block width.
    **kwargs
        Accepted for interface compatibility; not used here.

    Returns
    -------
    ndarray
        Solution of length ``Amat.shape[1]``; entries belonging to dropped
        zero columns are 0.
    """
    Alen = Amat.shape[1]

    # Remove zero columns
    idx = np.flatnonzero(np.sum(np.abs(Amat), axis=0) != 0)
    if len(idx) < Alen:
        Amat = Amat[:, idx]

    # Remove zero rows
    valid_rows = np.sum(np.abs(Amat), axis=1) != 0
    if np.sum(valid_rows) < Amat.shape[0]:
        Amat = Amat[valid_rows, :]
        rhs = rhs[valid_rows]

    # Decide how invR applies: to everything, to matching blocks, or not at all
    whole = invR is not None and invR.shape[0] == Amat.shape[1]
    spans = []
    if invR is not None and not whole and blocks is not None:
        # NOTE(review): offsets come from `blocks` and ignore any zero columns
        # removed above, exactly as before - confirm this is intended.
        edges = np.cumsum([0] + [blocks[k][1] for k in blocks])
        spans = [
            (start, stop)
            for start, stop in zip(edges[:-1], edges[1:])
            if stop - start == invR.shape[0]
        ]

    # Apply regularization transform
    if whole:
        Amat = Amat @ invR
    elif spans:
        # Work on a copy so the in-place block writes never reach the caller
        Amat = np.array(Amat, copy=True)
        for start, stop in spans:
            Amat[:, start:stop] = Amat[:, start:stop] @ invR

    rhs = rhs.flatten()

    # Build Hessian in dual space with the Hermitian transpose
    Aherm = Amat.conj().T
    Hess = Amat @ Aherm
    Hess[np.diag_indices_from(Hess)] += lambda_

    # Normalize and solve
    Hess_norm, Gdiag = _normalize_diag(Hess)

    if sparse.issparse(Hess_norm):
        y = Gdiag * spsolve(Hess_norm, Gdiag * rhs)
    else:
        y = Gdiag * np.linalg.solve(Hess_norm, Gdiag * rhs)

    # Transform back to primal space
    res = Aherm @ y
    if whole:
        res = invR @ res
    else:
        for start, stop in spans:
            res[start:stop] = invR @ res[start:stop]

    # Restore full length, keeping the dtype of the solution
    if len(idx) < Alen:
        res_full = np.zeros(Alen, dtype=res.dtype)
        res_full[idx] = res
        res = res_full

    return res
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def matreform(
    Amat: np.ndarray, ymeas: np.ndarray, ymodel: np.ndarray, form: str = "complex"
) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Rewrite the linear system Amat * x = (ymeas - ymodel) in another form.

    Parameters
    ----------
    Amat : ndarray
        System matrix.
    ymeas, ymodel : ndarray
        Measured and modelled data; their difference is the right-hand side.
    form : str
        'complex': matrix and difference returned untouched
        'real': real parts; if both matrix and difference have imaginary
            content, real and imaginary rows are stacked vertically
        'reim': like 'real', but the matrix becomes [[Re, -Im], [Im, Re]]
        'logphase': log-amplitude (and, for complex ymodel, phase) form

    Returns
    -------
    (matrix, rhs, nblock) where nblock is 2 when real/imaginary (or
    amplitude/phase) parts were stacked and 1 otherwise.

    Raises
    ------
    ValueError
        If ``form`` is not one of the four names above.
    """
    diff = ymeas - ymodel

    if form == "complex":
        return Amat, diff, 1

    if form == "real" or form == "reim":
        has_imag = not (np.isreal(diff).all() or np.isreal(Amat).all())
        if not has_imag:
            return np.real(Amat), np.real(diff), 1

        re_part = np.real(Amat)
        im_part = np.imag(Amat)
        if form == "reim":
            stacked = np.block([[re_part, -im_part], [im_part, re_part]])
        else:
            stacked = np.vstack([re_part, im_part])
        split_rhs = np.concatenate([np.real(diff), np.imag(diff)])
        return stacked, split_rhs, 2

    if form == "logphase":
        # Row scaling conj(y)/|y|^2 turns dy into d(log y)
        scale = np.conj(ymodel) / np.abs(ymodel * ymodel)
        if Amat.ndim == 2:
            scaled = scale[:, np.newaxis] * Amat
        else:
            scaled = scale * Amat

        log_amp = np.log(np.abs(ymeas)) - np.log(np.abs(ymodel))

        if np.isreal(ymodel).all():
            return np.real(scaled), log_amp, 1

        phase = np.angle(ymeas) - np.angle(ymodel)
        return (
            np.vstack([np.real(scaled), np.imag(scaled)]),
            np.concatenate([log_amp, phase]),
            2,
        )

    raise ValueError(f"Unknown form: {form}")
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def matflat(Amat: Union[dict, np.ndarray], weight: np.ndarray = None) -> np.ndarray:
    """
    Flatten dict of matrices into single 2D matrix.

    A flat dict {name: matrix} is concatenated horizontally in key order.
    A nested dict {name: {inner: matrix}} is concatenated horizontally per
    inner key (taken from the first entry) and the results are stacked
    vertically. Each named block is multiplied by its entry in ``weight``
    (all ones when omitted). Anything that is not a dict is returned as is.
    """
    if not isinstance(Amat, dict):
        return Amat

    names = list(Amat.keys())
    scale = np.ones(len(names)) if weight is None else weight
    head = Amat[names[0]]

    if not isinstance(head, dict):
        # Single level: blocks side by side
        return np.hstack([Amat[name] * scale[pos] for pos, name in enumerate(names)])

    # Two levels: one horizontal band per inner key, bands stacked vertically
    bands = [
        np.hstack([Amat[name][inner] * scale[pos] for pos, name in enumerate(names)])
        for inner in list(head.keys())
    ]
    return np.vstack(bands)
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
def prior(seg: np.ndarray, priortype: str, params: dict = None) -> np.ndarray:
    """
    Generate structure-prior regularization matrix.

    Parameters
    ----------
    seg : array-like
        1-D segmentation labels, or a 2-D composition matrix with one row
        per unknown and one column per component.
    priortype : str
        'laplace': for 1-D seg; entries between members of the same label
            are -1/count, diagonal is 1
        'helmholtz': as 'laplace' but with -1/(count + beta)
        'comp': for 2-D seg; couples rows whose L1 distance is below
            alpha * (number of columns)
    params : dict, optional
        'beta' (default 1.0) for 'helmholtz' and 'comp';
        'alpha' (default 0.1) for 'comp'.

    Returns
    -------
    Dense ndarray for 'laplace'/'helmholtz', CSR sparse matrix for 'comp',
    or None when ``priortype`` is empty or does not match the shape of seg.
    """
    if not priortype:
        return None

    params = params or {}
    seg = np.asarray(seg)

    if seg.ndim == 1:
        if priortype not in ("laplace", "helmholtz"):
            return None

        # The two priors differ only by the offset added to the label count
        offset = params.get("beta", 1.0) if priortype == "helmholtz" else 0.0
        _, inverse, counts = np.unique(seg, return_inverse=True, return_counts=True)
        inverse = np.ravel(inverse)

        Lmat = np.eye(len(seg))
        for i, count in enumerate(counts):
            if count > 1:
                members = np.flatnonzero(inverse == i)
                Lmat[np.ix_(members, members)] = -1.0 / (count + offset)
        np.fill_diagonal(Lmat, 1.0)
        return Lmat

    if priortype == "comp" and seg.ndim == 2:
        alpha = params.get("alpha", 0.1)
        beta = params.get("beta", 1.0)
        n, nc = seg.shape

        # Collect the strictly-upper-triangular couplings one row at a time;
        # the distance to all later rows is computed in a single numpy call
        rows, cols, vals = [], [], []
        for i in range(n - 1):
            dist = np.abs(seg[i + 1 :] - seg[i]).sum(axis=1)
            near = np.flatnonzero(dist < alpha * nc)
            if near.size:
                rows.append(np.full(near.size, i))
                cols.append(near + i + 1)
                vals.append(-alpha - dist[near] / nc)

        if rows:
            r = np.concatenate(rows)
            c = np.concatenate(cols)
            v = np.concatenate(vals).astype(float)
        else:
            r = np.empty(0, dtype=int)
            c = np.empty(0, dtype=int)
            v = np.empty(0)

        upper = sparse.coo_matrix((v, (r, c)), shape=(n, n))
        coupling = (upper + upper.T).tocoo()

        # Symmetric normalization by the absolute row sums
        rowsum = np.abs(np.asarray(coupling.sum(axis=1)).ravel())
        coupling.data = coupling.data / (
            beta * np.sqrt(rowsum[coupling.row] * rowsum[coupling.col] + 1e-16)
        )

        return (coupling + sparse.eye(n)).tocsr()

    return None
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def syncprop(cfg: dict, recon: dict) -> Tuple[dict, dict]:
    """
    Synchronize properties between forward and reconstruction meshes.

    recon["param"] takes precedence: when present, it is pushed into
    cfg["param"] and recon["prop"] is not touched. Otherwise recon["prop"]
    (a single array, or a dict of arrays for multi-wavelength) is pushed
    into cfg["prop"].

    How a value is transferred:
    - Label-based (fewer than 50 entries/rows, judged on the first
      parameter or wavelength): copied as is.
    - Node/element based with mapid, mapweight and elem all present in
      recon (dual mesh): interpolated with meshinterp, passing None as the
      destination so a new array is created.
    - Node/element based without a mapping: copied as is.

    mapid is documented in this module as holding 1-based element indices
    into recon["elem"].

    Returns the (mutated) ``cfg`` and the unchanged ``recon``.
    """
    # Prefer iso2mesh's meshinterp, fall back to the bundled one
    try:
        from iso2mesh import meshinterp
    except ImportError:
        from .utility import meshinterp

    # Threshold to distinguish label-based from node/element-based data
    label_threshold = 50
    has_mapping = "mapid" in recon and "mapweight" in recon and "elem" in recon

    def _copy(value):
        # Scalars have no .copy(); they are immutable and can be shared
        return value.copy() if hasattr(value, "copy") else value

    def _to_forward(value):
        # Move node/element-based data from the recon mesh to the forward mesh
        if has_mapping:
            return meshinterp(
                value,
                recon["mapid"],
                recon["mapweight"],
                recon["elem"],  # 1-based, meshinterp converts
                None,  # create a new array of the correct size
            )
        return _copy(value)

    if "param" in recon:
        params = recon["param"]
        allkeys = list(params.keys())
        first_param = params[allkeys[0]]
        param_len = len(first_param) if hasattr(first_param, "__len__") else 1

        if param_len < label_threshold:
            # Label-based: replace cfg["param"] wholesale
            cfg["param"] = {k: _copy(v) for k, v in params.items()}
        else:
            # Node/element based: update key by key, keeping other cfg params
            if "param" not in cfg:
                cfg["param"] = {}
            for key in allkeys:
                cfg["param"][key] = _to_forward(params[key])

    elif "prop" in recon:
        prop = recon["prop"]

        if not isinstance(prop, dict):
            if prop.shape[0] < label_threshold:
                cfg["prop"] = prop.copy()
            else:
                cfg["prop"] = _to_forward(prop)
        else:
            # Multi-wavelength: the first wavelength decides label vs node
            allkeys = list(prop.keys())
            if prop[allkeys[0]].shape[0] < label_threshold:
                cfg["prop"] = {k: _copy(v) for k, v in prop.items()}
            else:
                cfg["prop"] = {k: _to_forward(prop[k]) for k in allkeys}

    return cfg, recon
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def _normalize_diag(A: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Normalize matrix to have unit diagonal for better conditioning.

    Computes d = 1/sqrt(|diag(A)| + 1e-16) and returns (D @ A @ D, d) with
    D = diag(d). Sparse input is scaled with sparse products and returned
    as CSR; any other input is converted to an ndarray first so the scaling
    is always element-wise.
    """
    if sparse.issparse(A):
        di = 1.0 / np.sqrt(np.abs(A.diagonal()) + 1e-16)
        scale = sparse.diags(di)
        return (scale @ A @ scale).tocsr(), di

    # np.matrix would turn `*` into a matrix product; force a plain array
    A = np.asarray(A)
    di = 1.0 / np.sqrt(np.abs(np.diag(A)) + 1e-16)
    Anorm = (di[:, np.newaxis] * di[np.newaxis, :]) * A
    return Anorm, di
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def _flatten_detphi(
    detphi: Union[np.ndarray, dict],
    sd: Union[np.ndarray, dict],
    wavelengths: List[str],
    rfcw: List[int],
) -> np.ndarray:
    """
    Flatten detector measurements from nested dict to 1D array.

    An ndarray is simply flattened. Otherwise, for every entry of
    ``wavelengths`` the per-wavelength data is looked up as detphi[wv]
    (falling back to ``detphi`` itself when the key is absent or detphi is
    not a dict). If that data is a dict, the array under [mode]["detphi"]
    is taken for every mode in ``rfcw``; if not, it is converted to an
    array directly. All pieces are concatenated in that order.

    ``sd`` is accepted for interface compatibility and not used.

    Raises
    ------
    KeyError
        If a requested mode has no "detphi" entry.
    """
    if isinstance(detphi, np.ndarray):
        return detphi.flatten()

    pieces = []
    for wv in wavelengths:
        phi_wv = detphi.get(wv, detphi) if isinstance(detphi, dict) else detphi

        if not isinstance(phi_wv, dict):
            pieces.append(np.asarray(phi_wv).flatten())
            continue

        for md in rfcw:
            entry = phi_wv.get(md, {})
            if not isinstance(entry, dict) or "detphi" not in entry:
                raise KeyError(
                    f"no 'detphi' data for mode {md!r} (wavelength {wv!r})"
                )
            pieces.append(np.asarray(entry["detphi"]).flatten())

    if not pieces:
        return np.array([])

    # One concatenation instead of extending a Python list element by element
    return np.concatenate(pieces)
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
def _remap_jacobian(J: np.ndarray, recon: dict, cfg: dict) -> np.ndarray:
    """
    Remap Jacobian from forward mesh nodes to reconstruction mesh nodes.

    Every forward-mesh column is distributed onto the 4 nodes of its
    enclosing reconstruction element using the weights in
    recon["mapweight"]. Forward nodes are skipped when their mapid is NaN,
    when the element index is out of range, or when any of the element's
    node indices is out of range.

    Parameters
    ----------
    J : ndarray
        Jacobian on forward mesh, Nsd x Nn_forward. A horizontally stacked
        Jacobian whose column count is a multiple of len(mapid) is remapped
        block by block.
    recon : dict
        Reconstruction structure with node, mapid (1-based), mapweight
        (one row of 4 weights per forward node) and elem (1-based)
    cfg : dict
        Forward structure (not used)

    Returns
    -------
    J_new : ndarray
        Jacobian on reconstruction mesh, Nsd x (nblock * Nn_recon), with
        the dtype of ``J``.

    Raises
    ------
    ValueError
        If J has more columns than mapid entries and the column count is
        not a multiple of len(mapid).
    """
    nn_recon = recon["node"].shape[0]
    ncol = J.shape[1]

    mapid = np.asarray(recon["mapid"], dtype=float).ravel()
    mapweight = np.asarray(recon["mapweight"])
    n_map = mapid.shape[0]

    # Convert 1-based elem to 0-based for numpy indexing
    elem_0 = np.asarray(recon["elem"])[:, :4].astype(int) - 1
    n_elem = elem_0.shape[0]

    if ncol <= n_map:
        nn_forward, nblock = ncol, 1
    elif n_map > 0 and ncol % n_map == 0:
        nn_forward, nblock = n_map, ncol // n_map
    else:
        raise ValueError(
            f"Jacobian has {ncol} columns, which is not a multiple of the "
            f"{n_map} forward nodes in recon['mapid']"
        )

    # Forward nodes with a usable enclosing element
    ids = mapid[:nn_forward]
    inside = ~np.isnan(ids)
    eid = np.full(nn_forward, -1, dtype=int)
    eid[inside] = ids[inside].astype(int) - 1  # 1-based to 0-based
    inside &= (eid >= 0) & (eid < n_elem)

    fwd = np.flatnonzero(inside)
    nodes = elem_0[eid[fwd]]
    in_range = np.all((nodes >= 0) & (nodes < nn_recon), axis=1)
    fwd = fwd[in_range]
    nodes = nodes[in_range]
    weights = mapweight[fwd, :4]

    # Sparse (Nn_recon x Nn_forward) mapping; duplicate entries are summed
    mapper = sparse.coo_matrix(
        (weights.ravel(), (nodes.ravel(), np.repeat(fwd, 4))),
        shape=(nn_recon, nn_forward),
    ).tocsr()

    pieces = [
        (mapper @ J[:, b * nn_forward : (b + 1) * nn_forward].T).T
        for b in range(nblock)
    ]
    J_new = pieces[0] if nblock == 1 else np.hstack(pieces)

    return np.ascontiguousarray(J_new, dtype=J.dtype)
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
def _masksum(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Collapse columns of ``data`` that share a label in ``mask``.

    The output has one column per distinct label, ordered as returned by
    np.unique, and each column is the sum of all input columns carrying
    that label. The result keeps the dtype of ``data``.
    """
    distinct = np.unique(mask)
    summed = np.zeros((data.shape[0], len(distinct)), dtype=data.dtype)

    for col, members in enumerate(mask == value for value in distinct):
        summed[:, col] = np.sum(data[:, members], axis=1)

    return summed
|