redbirdpy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- redbirdpy/__init__.py +112 -0
- redbirdpy/analytical.py +927 -0
- redbirdpy/forward.py +589 -0
- redbirdpy/property.py +602 -0
- redbirdpy/recon.py +893 -0
- redbirdpy/solver.py +814 -0
- redbirdpy/utility.py +1117 -0
- redbirdpy-0.1.0.dist-info/METADATA +596 -0
- redbirdpy-0.1.0.dist-info/RECORD +13 -0
- redbirdpy-0.1.0.dist-info/WHEEL +5 -0
- redbirdpy-0.1.0.dist-info/licenses/LICENSE.txt +674 -0
- redbirdpy-0.1.0.dist-info/top_level.txt +1 -0
- redbirdpy-0.1.0.dist-info/zip-safe +1 -0
redbirdpy/forward.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Redbird Forward Module - FEM-based forward modeling for diffuse optics.
|
|
3
|
+
|
|
4
|
+
INDEX CONVENTION: All mesh indices (elem, face) stored in cfg are 1-based
|
|
5
|
+
to match MATLAB/iso2mesh. This module converts to 0-based internally when
|
|
6
|
+
indexing numpy arrays, using local variables named with '_0' suffix.
|
|
7
|
+
|
|
8
|
+
Functions:
|
|
9
|
+
runforward: Main forward solver for all sources/wavelengths
|
|
10
|
+
femlhs: Build the FEM left-hand-side system matrix (stiffness, mass and boundary terms)
|
|
11
|
+
femrhs: Build FEM right-hand-side vector
|
|
12
|
+
femgetdet: Extract detector values from forward solution
|
|
13
|
+
jac: Compute mua Jacobian matrices using the adjoint method
jacchrome: Convert per-wavelength mua Jacobians into chromophore Jacobians
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
# Public names exported by ``from redbirdpy.forward import *``.
__all__ = [
    "runforward", "femlhs", "femrhs", "femgetdet",
    "jac", "jacchrome", "C0",
]
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
from scipy import sparse
|
|
28
|
+
from typing import Dict, Tuple, Optional, Union, List, Any
|
|
29
|
+
|
|
30
|
+
# Import solver functions from solver module
|
|
31
|
+
from .solver import femsolve
|
|
32
|
+
from .utility import sdmap, getoptodes, deldotdel
|
|
33
|
+
from .property import extinction
|
|
34
|
+
|
|
35
|
+
# Speed of light in vacuum in mm/s (299792458 m/s expressed in millimetres).
C0 = 2.99792458e11
# Reciprocal of C0; femlhs multiplies by it when building the omega term.
R_C0 = 1.0 / C0
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def runforward(cfg: dict, **kwargs) -> Tuple[Any, Any]:
    """
    Perform forward simulations at all sources and all wavelengths.

    For every wavelength key and every mode in ``rfcw`` this builds the RHS
    (``femrhs``) and system matrix (``femlhs``), solves with ``femsolve`` and
    samples the detectors with ``femgetdet``.

    Parameters
    ----------
    cfg : dict
        Simulation configuration. When ``cfg["deldotdel"]`` is missing or
        None it is computed with ``deldotdel(cfg)`` and stored back into
        ``cfg`` (in-place side effect). When ``cfg["prop"]`` is a dict, its
        keys are treated as wavelengths; otherwise a single wavelength ``""``
        is simulated.
    **kwargs
        rfcw : int or sequence of int, optional
            Mode(s) to simulate, default ``[1]``. ``femlhs`` forces
            ``omega = 0`` for mode 2.
        sd : dict or array, optional
            Source/detector map; computed with ``sdmap(cfg)`` when absent.
            A non-dict value is reused for every wavelength.
        All keyword arguments (including the ones above) are forwarded
        unchanged to ``femsolve``.

    Returns
    -------
    detval, phi
        Nested as ``detval[md]["detphi"][wv]`` / ``phi[md]["phi"][wv]``.
        The wavelength level is dropped when there is exactly one
        wavelength. The ``[md]["detphi"]`` / ``[md]["phi"]`` levels are
        dropped when there is exactly one mode.
    """
    rfcw = kwargs.get("rfcw", [1])
    # Accept numpy integer scalars as well as Python ints for a single mode.
    if isinstance(rfcw, (int, np.integer)):
        rfcw = [rfcw]

    if cfg.get("deldotdel") is None:
        cfg["deldotdel"], _ = deldotdel(cfg)

    wavelengths = (
        list(cfg["prop"].keys()) if isinstance(cfg.get("prop"), dict) else [""]
    )

    sd = kwargs.get("sd")
    if sd is None:
        sd = sdmap(cfg)
    if not isinstance(sd, dict):
        sd = {wv: sd for wv in wavelengths}

    detval_out = {md: {"detphi": {}} for md in rfcw}
    phi_out = {md: {"phi": {}} for md in rfcw}

    for wv in wavelengths:
        for md in rfcw:
            rhs, loc, bary, _optode = femrhs(cfg, sd, wv, md)
            amat = femlhs(cfg, cfg["deldotdel"], wv, md)
            phi_sol, _flag = femsolve(amat, rhs, **kwargs)
            phi_out[md]["phi"][wv] = phi_sol
            # rhs is passed so femgetdet can also sample wide-field detectors
            detval_out[md]["detphi"][wv] = femgetdet(phi_sol, cfg, rhs, loc, bary)

    # Collapse singleton levels for convenience.
    if len(wavelengths) == 1:
        only_wv = wavelengths[0]
        for md in rfcw:
            phi_out[md]["phi"] = phi_out[md]["phi"][only_wv]
            detval_out[md]["detphi"] = detval_out[md]["detphi"][only_wv]

    if len(rfcw) == 1:
        phi_out = phi_out[rfcw[0]]["phi"]
        detval_out = detval_out[rfcw[0]]["detphi"]

    return detval_out, phi_out
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def femlhs(
    cfg: dict, deldotdel_mat: np.ndarray, wavelength: str = "", mode: int = 1
) -> sparse.csr_matrix:
    """
    Assemble the FEM left-hand-side (system) matrix on a tetrahedral mesh.

    The matrix is the sum of four terms:
    - a diffusion term ``deldotdel_mat * D`` with ``D = 1 / (3 (mua + musp))``;
    - an absorption mass term in ``mua``;
    - when ``omega > 0``, an imaginary term ``1j * omega * R_C0 * nref``
      using the same mass weights;
    - a boundary term on ``cfg["face"]`` scaled by
      ``(1 - reff) / (12 (1 + reff))``.

    Parameters
    ----------
    cfg : dict
        Required keys:
        - ``node`` (Nn rows);
        - ``elem`` (Ne x >=4, 1-based);
        - ``face`` (Nf x 3, 1-based);
        - ``evol`` (Ne,);
        - ``area`` (Nf,);
        - ``prop``.

        Optional keys:
        - ``seg``;
        - ``reff``, default 0.493;
        - ``omega``, default 0.

        ``prop``, ``reff`` and ``omega`` may be dicts keyed by wavelength.
        ``prop`` has one row per node, per element, or per ``seg`` value.
        Its columns are:
        - 0: mua;
        - 1: scattering coefficient;
        - 2 (optional): anisotropy, giving ``musp = prop[:, 1] * (1 - prop[:, 2])``;
        - 3 (optional): refractive index, default 1.37.
    deldotdel_mat : ndarray (Ne x 10)
        Per-element gradient products. Columns 0, 4, 7, 9 feed the diagonal.
        Columns 1, 2, 3, 5, 6, 8 feed the node pairs
        (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
    wavelength : str, optional
        Key used when ``cfg["prop"]`` is a dict.
    mode : int, optional
        ``mode == 2`` disables the omega term.

    Returns
    -------
    scipy.sparse.csr_matrix (Nn x Nn)
        Complex when ``omega > 0``, otherwise real.

    Raises
    ------
    ValueError
        If ``prop`` matches neither the node/element count nor a ``seg`` map.
    """
    nn = cfg["node"].shape[0]
    ne = cfg["elem"].shape[0]
    evol = cfg["evol"]
    area = cfg["area"]

    # cfg stores 1-based (MATLAB/iso2mesh) indices; convert for numpy indexing.
    elem_0 = cfg["elem"][:, :4].astype(np.int32) - 1
    face_0 = cfg["face"].astype(np.int32) - 1

    # Resolve per-wavelength settings. reff falls back to 0.493 whether or
    # not prop is keyed by wavelength.
    props = cfg["prop"]
    reff = cfg.get("reff", 0.493)
    omega = cfg.get("omega", 0)
    if isinstance(props, dict) and wavelength:
        props = props[wavelength]
        if isinstance(reff, dict):
            reff = reff[wavelength]
        if isinstance(omega, dict):
            omega = omega.get(wavelength, 0)

    if mode == 2:
        omega = 0

    # Expand properties to one row per node or per element.
    seg = cfg.get("seg", None)
    if props.shape[0] == nn or props.shape[0] == ne:
        prop_rows = props
    elif seg is not None:
        # seg holds, for every node/element, the row of props to use.
        seg_idx = np.clip(np.asarray(seg).astype(np.int32), 0, props.shape[0] - 1)
        prop_rows = props[seg_idx]
    else:
        raise ValueError("Property format not recognized")

    ncol = props.shape[1]
    mua = prop_rows[:, 0]
    musp = prop_rows[:, 1] * (1 - prop_rows[:, 2]) if ncol >= 3 else prop_rows[:, 1]
    # Refractive index per row, so label-based media with varying n are
    # handled; 1.37 when prop has no 4th column.
    nref = prop_rows[:, 3] if ncol >= 4 else 1.37

    dcoeff = 1.0 / (3.0 * (mua + musp))

    rows_list = []
    cols_list = []
    vals_list = []

    def _add(r, c, v):
        # Queue one batch of COO triplets; duplicates are summed on conversion.
        rows_list.append(r)
        cols_list.append(c)
        vals_list.append(v)

    pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
    offdiag_idx = [1, 2, 3, 5, 6, 8]
    diag_idx = [0, 4, 7, 9]

    if len(mua) == ne:
        # Element-wise constant properties. The linear-tetrahedron mass
        # matrix is 0.05*evol off the diagonal and 0.10*evol on it.
        for k, (i, j) in enumerate(pairs):
            val = deldotdel_mat[:, offdiag_idx[k]] * dcoeff + 0.05 * mua * evol
            if omega > 0:
                val = val.astype(complex) + 1j * 0.05 * omega * R_C0 * nref * evol
            _add(elem_0[:, i], elem_0[:, j], val)
            _add(elem_0[:, j], elem_0[:, i], val)

        for k in range(4):
            val = deldotdel_mat[:, diag_idx[k]] * dcoeff + 0.10 * mua * evol
            if omega > 0:
                val = val.astype(complex) + 1j * 0.10 * omega * R_C0 * nref * evol
            _add(elem_0[:, k], elem_0[:, k], val)
    else:
        # Node-wise properties, interpolated linearly inside each element.
        # w1[:, k] weights the 4 nodal values for node pair k.
        # w2[:, k] weights them for diagonal entry k.
        w1 = (1 / 120) * np.array(
            [
                [2, 2, 1, 1],
                [2, 1, 2, 1],
                [2, 1, 1, 2],
                [1, 2, 2, 1],
                [1, 2, 1, 2],
                [1, 1, 2, 2],
            ]
        ).T
        w2 = (1 / 60) * (np.diag([2, 2, 2, 2]) + 1)

        mua_e = mua[elem_0]
        dcoeff_e = np.mean(dcoeff[elem_0], axis=1)
        nref_e = nref[elem_0] if hasattr(nref, "__len__") and len(nref) == nn else nref
        nref_per_node = hasattr(nref_e, "__len__")

        for k, (i, j) in enumerate(pairs):
            val = (
                deldotdel_mat[:, offdiag_idx[k]] * dcoeff_e + (mua_e @ w1[:, k]) * evol
            )
            if omega > 0:
                if nref_per_node:
                    val = (
                        val.astype(complex)
                        + 1j * omega * R_C0 * (nref_e @ w1[:, k]) * evol
                    )
                else:
                    # each w1 column sums to 0.05
                    val = val.astype(complex) + 1j * omega * R_C0 * nref_e * 0.05 * evol
            _add(elem_0[:, i], elem_0[:, j], val)
            _add(elem_0[:, j], elem_0[:, i], val)

        for k in range(4):
            val = deldotdel_mat[:, diag_idx[k]] * dcoeff_e + (mua_e @ w2[:, k]) * evol
            if omega > 0:
                if nref_per_node:
                    val = (
                        val.astype(complex)
                        + 1j * omega * R_C0 * (nref_e @ w2[:, k]) * evol
                    )
                else:
                    # each w2 column sums to 0.10
                    val = val.astype(complex) + 1j * omega * R_C0 * nref_e * 0.10 * evol
            _add(elem_0[:, k], elem_0[:, k], val)

    # Boundary term on surface triangles: diagonal area*c, off-diagonal half.
    bc_coeff = (1 - reff) / (12.0 * (1 + reff))
    adiag_bc = area * bc_coeff
    aoff_bc = adiag_bc * 0.5

    for i, j in [(0, 1), (0, 2), (1, 2)]:
        _add(face_0[:, i], face_0[:, j], aoff_bc)
        _add(face_0[:, j], face_0[:, i], aoff_bc)

    for k in range(3):
        _add(face_0[:, k], face_0[:, k], adiag_bc)

    rows = np.concatenate(rows_list)
    cols = np.concatenate(cols_list)
    vals = np.concatenate(vals_list)

    dtype = complex if omega > 0 else float
    return sparse.coo_matrix((vals, (rows, cols)), shape=(nn, nn), dtype=dtype).tocsr()
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def femrhs(
    cfg: dict, sd: dict = None, wv: str = "", md: int = 1
) -> Tuple[sparse.spmatrix, np.ndarray, np.ndarray, np.ndarray]:
    """
    Create right-hand-side vectors for the FEM system.

    One column is produced per optode, in the order
    ``[point sources, wide-field sources, point detectors, wide-field detectors]``.
    A point optode puts its barycentric weights on the four nodes of its
    enclosing element. It gets no entries if ``tsearchn`` returns NaN for it.
    A wide-field column is copied from ``widesrc`` / ``widedet``, which are
    (Nn x Npattern).

    Parameters
    ----------
    cfg : dict
        Passed to ``getoptodes``. Must contain ``node`` and ``elem``
        (1-based connectivity).
    sd : dict, optional
        Not used here; accepted for call-signature compatibility.
    wv : str, optional
        Wavelength key forwarded to ``getoptodes``.
    md : int, optional
        Not used here; accepted for call-signature compatibility.

    Returns
    -------
    rhs : scipy.sparse.csr_matrix (Nn x Ncols)
        RHS vectors in the column order above.
    loc : ndarray (Ncols,)
        Element IDs enclosing each point optode, as returned by
        ``iso2mesh.tsearchn`` (1-based). NaN for wide-field columns.
    bary : ndarray (Ncols x 4)
        Barycentric coordinates for point optodes. NaN rows for wide-field.
    optode : ndarray
        Point source positions stacked over point detector positions.
        Empty array when there are none.
    """
    import iso2mesh as i2m

    optsrc, optdet, widesrc, widedet = getoptodes(cfg, wv)

    def _as_points(arr):
        # Promote a single (3,) position to (1, 3) BEFORE counting, so one
        # point optode yields one column. None when missing or empty.
        if arr is None:
            return None
        arr = np.atleast_2d(arr)
        return arr if arr.size > 0 else None

    def _num_patterns(arr):
        # Wide-field patterns are stored column-wise (Nn x Npattern).
        return arr.shape[1] if arr is not None and arr.size > 0 else 0

    optsrc = _as_points(optsrc)
    optdet = _as_points(optdet)
    srcnum = 0 if optsrc is None else optsrc.shape[0]
    detnum = 0 if optdet is None else optdet.shape[0]
    wfsrcnum = _num_patterns(widesrc)
    wfdetnum = _num_patterns(widedet)

    nn = cfg["node"].shape[0]
    total_cols = srcnum + wfsrcnum + detnum + wfdetnum

    if total_cols == 0:
        return (
            sparse.csr_matrix((nn, 0)),
            np.array([]),
            np.array([]).reshape(0, 4),
            np.array([]),
        )

    rhs = sparse.lil_matrix((nn, total_cols))
    # Wide-field entries stay NaN in loc/bary.
    loc = np.full(total_cols, np.nan)
    bary = np.full((total_cols, 4), np.nan)

    # tsearchn takes the 1-based connectivity; elem_0 is for numpy indexing.
    elem = cfg["elem"][:, :4].astype(np.int32)
    elem_0 = elem - 1

    def _place_points(points, start):
        # Locate each point in the mesh. Write its barycentric weights into
        # column start+i at the enclosing element's nodes, and record loc/bary.
        npts = points.shape[0]
        locs, barys = i2m.tsearchn(cfg["node"], elem, points[:, :3])
        for i in range(npts):
            if not np.isnan(locs[i]):
                eid = int(locs[i]) - 1
                rhs[elem_0[eid, :], start + i] = barys[i, :]
        loc[start : start + npts] = locs
        bary[start : start + npts, :] = barys

    if srcnum > 0:
        _place_points(optsrc, 0)

    if wfsrcnum > 0:
        rhs[:, srcnum : srcnum + wfsrcnum] = widesrc

    det_start = srcnum + wfsrcnum
    if detnum > 0:
        _place_points(optdet, det_start)

    if wfdetnum > 0:
        rhs[:, det_start + detnum : total_cols] = widedet

    optode_list = [pts for pts in (optsrc, optdet) if pts is not None]
    optode = np.vstack(optode_list) if optode_list else np.array([])

    return rhs.tocsr(), loc, bary, optode
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def femgetdet(
    phi: np.ndarray,
    cfg: dict,
    rhs: np.ndarray,
    loc: np.ndarray = None,
    bary: np.ndarray = None,
) -> np.ndarray:
    """
    Sample the forward solution at every detector.

    ``rhs`` and ``phi`` share the column layout
    ``[point src | wide src | point det | wide det]``.
    Each detector value is the inner product of one detector RHS column with
    one source solution column: ``detval = rhs_det.T @ phi_src``.

    Parameters
    ----------
    phi : ndarray (Nn x Ncols)
        Forward solution. Only the source columns are read.
    cfg : dict
        Optode counts are taken from ``srcpos`` / ``detpos`` (rows) and
        ``widesrc`` / ``widedet`` (columns, Nn x Npattern).
    rhs : ndarray or sparse matrix (Nn x Ncols)
        RHS matrix, as produced by femrhs.
    loc, bary : ndarray, optional
        Accepted for API compatibility; not used.

    Returns
    -------
    ndarray (total_det x total_src)
        Empty array when there are no sources or no detectors.
    """

    def _rows(key):
        # Number of point optodes; a single (3,) position counts as one.
        value = cfg.get(key)
        if value is None:
            return 0
        pts = np.atleast_2d(value)
        return pts.shape[0] if pts.size > 0 else 0

    def _patterns(key):
        # Number of wide-field patterns stored column-wise.
        value = cfg.get(key)
        if value is None or value.size == 0:
            return 0
        return value.shape[1]

    n_point_src = _rows("srcpos")
    n_point_det = _rows("detpos")
    n_src = n_point_src + _patterns("widesrc")
    n_det = n_point_det + _patterns("widedet")

    if not n_src or not n_det:
        return np.array([])

    # Detector columns sit right after all source columns.
    det_block = rhs[:, n_src : n_src + n_det]
    if sparse.issparse(rhs):
        det_block = det_block.toarray()

    return det_block.T @ phi[:, :n_src]
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
# Optional Numba acceleration for the Jacobian kernel. Availability is
# reported through HAS_NUMBA rather than printed to stdout at import time.
try:
    from numba import njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

if HAS_NUMBA:

    @njit(parallel=True, cache=True)
    def _jac_core(phi, elem_0, evol, src_cols, det_cols):
        """
        Numba kernel for the element-wise mua Jacobian.

        For pair ``isd`` and element ``ie``, with ``ps`` / ``pd`` the values
        of ``phi`` at the element's four (0-based) nodes in columns
        ``src_cols[isd]`` / ``det_cols[isd]``:

            J[isd, ie] = -0.1 * evol[ie] * (sum_i ps_i*pd_i
                                            + 0.5 * sum_{i != j} ps_i*pd_j)

        The output has ``phi``'s dtype, so complex fields (omega > 0) keep
        their imaginary part; a float64 buffer cannot store complex values.
        """
        nelem = elem_0.shape[0]
        nsd = len(src_cols)
        Jmua_elem = np.zeros((nsd, nelem), dtype=phi.dtype)

        for isd in prange(nsd):
            src_col = src_cols[isd]
            det_col = det_cols[isd]

            for ie in range(nelem):
                n0, n1, n2, n3 = (
                    elem_0[ie, 0],
                    elem_0[ie, 1],
                    elem_0[ie, 2],
                    elem_0[ie, 3],
                )

                ps0, ps1, ps2, ps3 = (
                    phi[n0, src_col],
                    phi[n1, src_col],
                    phi[n2, src_col],
                    phi[n3, src_col],
                )
                pd0, pd1, pd2, pd3 = (
                    phi[n0, det_col],
                    phi[n1, det_col],
                    phi[n2, det_col],
                    phi[n3, det_col],
                )

                diag_sum = ps0 * pd0 + ps1 * pd1 + ps2 * pd2 + ps3 * pd3
                cross_sum = (
                    ps0 * pd1
                    + ps1 * pd0
                    + ps0 * pd2
                    + ps2 * pd0
                    + ps0 * pd3
                    + ps3 * pd0
                    + ps1 * pd2
                    + ps2 * pd1
                    + ps1 * pd3
                    + ps3 * pd1
                    + ps2 * pd3
                    + ps3 * pd2
                )

                Jmua_elem[isd, ie] = -(diag_sum + cross_sum * 0.5) * 0.1 * evol[ie]

        return Jmua_elem
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def jac(sd, phi, deldotdel_mat, elem, evol, iselem=False):
    """
    Build mua Jacobian matrices with the adjoint method.

    Uses the Numba kernel when available and ``phi`` is real. Complex
    ``phi`` always takes the NumPy path, which preserves the complex dtype.

    Parameters
    ----------
    sd : ndarray (Npair x >=2)
        Source/detector pairs. Columns 0 and 1 are used directly as numpy
        column indices into ``phi`` (presumably 0-based from sdmap — TODO
        confirm). If a third column exists, only rows whose value there
        equals 1 are used.
    phi : ndarray (Nn x Ncol)
        Forward solutions, real or complex.
    deldotdel_mat : ndarray
        Not used by the mua Jacobian; kept for API compatibility.
    elem : ndarray (Ne x >=4)
        1-based tetrahedral connectivity; only the first 4 columns are used.
    evol : ndarray (Ne,)
        Element volumes.
    iselem : bool, optional
        Not used; both node and element Jacobians are always returned.

    Returns
    -------
    Jmua_node : ndarray (Npair_active x Nn)
        0.25 times the sum of ``Jmua_elem`` over the elements containing
        each node.
    Jmua_elem : ndarray (Npair_active x Ne)
        Per-element sensitivity.
    """
    elem_0 = elem[:, :4].astype(np.int32) - 1
    nelem = elem_0.shape[0]
    nn = phi.shape[0]

    pairs = sd[sd[:, 2] == 1, :2] if sd.shape[1] >= 3 else sd[:, :2]
    pairs = pairs.astype(np.int32)
    src_cols = pairs[:, 0]
    det_cols = pairs[:, 1]

    if HAS_NUMBA and not np.iscomplexobj(phi):
        Jmua_elem = _jac_core(
            np.ascontiguousarray(phi), elem_0, evol, src_cols, det_cols
        )
    else:
        Jmua_elem = np.zeros((pairs.shape[0], nelem), dtype=phi.dtype)
        for isd, (s_col, d_col) in enumerate(zip(src_cols, det_cols)):
            ps = phi[elem_0, s_col]  # (Ne x 4) source field at element nodes
            pd = phi[elem_0, d_col]  # (Ne x 4) detector field at element nodes
            # sum_i ps_i*pd_i + 0.5*sum_{i!=j} ps_i*pd_j
            #   == 0.5 * (sum_i ps_i*pd_i + (sum_i ps_i) * (sum_j pd_j))
            Jmua_elem[isd, :] = (
                -0.05
                * evol
                * ((ps * pd).sum(axis=1) + ps.sum(axis=1) * pd.sum(axis=1))
            )

    # Node-to-element averaging operator: 0.25 at (node, element) for each
    # of an element's four nodes.
    P = sparse.csr_matrix(
        (
            np.full(nelem * 4, 0.25),
            (elem_0.ravel(), np.repeat(np.arange(nelem), 4)),
        ),
        shape=(nn, nelem),
    )
    Jmua_node = (P @ Jmua_elem.T).T

    return Jmua_node, Jmua_elem
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def jacchrome(Jmua: dict, chromophores: List[str]) -> dict:
    """
    Build chromophore Jacobians from per-wavelength mua Jacobians.

    For each chromophore, every wavelength's ``Jmua[wv]`` is scaled by that
    chromophore's extinction at ``wv``. The scaled blocks are stacked
    vertically in wavelength order.

    Parameters
    ----------
    Jmua : dict
        Maps wavelength key -> mua Jacobian array.
    chromophores : list of str
        Chromophore names passed to ``extinction``.

    Returns
    -------
    dict
        Maps chromophore name -> stacked Jacobian. With a single wavelength
        the scaled block is returned as-is; with no wavelengths the value is
        None.

    Raises
    ------
    ValueError
        If ``Jmua`` is not a dict.
    """
    if not isinstance(Jmua, dict):
        raise ValueError("Jmua must be a dict with wavelength keys")

    wavelengths = list(Jmua.keys())
    # extin is indexed [wavelength, chromophore]
    extin, _ = extinction(wavelengths, chromophores)

    Jchrome = {}
    for i, ch in enumerate(chromophores):
        blocks = [Jmua[wv] * extin[j, i] for j, wv in enumerate(wavelengths)]
        if not blocks:
            Jchrome[ch] = None
        elif len(blocks) == 1:
            Jchrome[ch] = blocks[0]
        else:
            # One stack instead of re-stacking inside the loop (was quadratic).
            Jchrome[ch] = np.vstack(blocks)

    return Jchrome
|