redbirdpy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
redbirdpy/forward.py ADDED
@@ -0,0 +1,589 @@
1
"""
Redbird Forward Module - FEM-based forward modeling for diffuse optics.

INDEX CONVENTION: All mesh indices (elem, face) stored in cfg are 1-based
to match MATLAB/iso2mesh. This module converts them to 0-based internally
when indexing numpy arrays, using local variables named with a '_0' suffix.

Functions:
    runforward: Main forward solver for all sources/wavelengths
    femlhs: Build FEM left-hand-side (system) matrix
    femrhs: Build FEM right-hand-side vectors
    femgetdet: Extract detector values from forward solution
    jac: Compute mua Jacobian matrices using the adjoint method
    jacchrome: Convert per-wavelength mua Jacobians into chromophore Jacobians

Constants:
    C0: Speed of light in vacuum, in mm/s
"""

__all__ = [
    "runforward",
    "femlhs",
    "femrhs",
    "femgetdet",
    "jac",
    "jacchrome",
    "C0",
]
25
+
26
+ import numpy as np
27
+ from scipy import sparse
28
+ from typing import Dict, Tuple, Optional, Union, List, Any
29
+
30
+ # Import solver functions from solver module
31
+ from .solver import femsolve
32
+ from .utility import sdmap, getoptodes, deldotdel
33
+ from .property import extinction
34
+
35
# Speed of light in vacuum expressed in millimetres per second
# (299,792,458 m/s x 1000 mm/m); mesh coordinates are in mm.
C0 = 2.99792458e11
# Reciprocal of C0 (s/mm), used as a multiplier in frequency-domain terms.
R_C0 = 1.0 / C0
38
+
39
+
40
def runforward(cfg: dict, **kwargs) -> Tuple[Any, Any]:
    """
    Perform forward simulations at all sources and all wavelengths.

    Parameters
    ----------
    cfg : dict
        Simulation configuration. If ``cfg["deldotdel"]`` is missing or None
        it is computed with ``deldotdel(cfg)`` and stored back into ``cfg``
        (``cfg`` is mutated in place). When ``cfg["prop"]`` is a dict its keys
        are the wavelengths to simulate; otherwise a single unnamed
        wavelength ``""`` is used.
    **kwargs
        rfcw : int or sequence of int, optional
            Modes to simulate (default ``[1]``). A scalar (Python or NumPy
            integer) is treated as a single mode.
        sd : optional
            Source/detector map; computed with ``sdmap(cfg)`` when omitted.
            A non-dict value is reused for every wavelength.
        All keyword arguments (including the ones above) are forwarded to
        ``femsolve``.

    Returns
    -------
    detval_out, phi_out
        Nested as ``{mode: {"detphi" | "phi": {wavelength: value}}}``. The
        wavelength level is collapsed when there is a single wavelength, and
        the mode level (with its "detphi"/"phi" key) is collapsed when there
        is a single mode.
    """
    rfcw = kwargs.get("rfcw", [1])
    if isinstance(rfcw, (int, np.integer)):
        rfcw = [int(rfcw)]

    if cfg.get("deldotdel") is None:
        cfg["deldotdel"], _ = deldotdel(cfg)

    if isinstance(cfg.get("prop"), dict):
        wavelengths = list(cfg["prop"].keys())
    else:
        wavelengths = [""]

    sd = kwargs.get("sd")
    if sd is None:
        sd = sdmap(cfg)
    if not isinstance(sd, dict):
        sd = {wv: sd for wv in wavelengths}

    detval_out = {md: {"detphi": {}} for md in rfcw}
    phi_out = {md: {"phi": {}} for md in rfcw}

    for wv in wavelengths:
        for md in rfcw:
            rhs, loc, bary, _optode = femrhs(cfg, sd, wv, md)
            amat = femlhs(cfg, cfg["deldotdel"], wv, md)
            phi_sol, _flag = femsolve(amat, rhs, **kwargs)
            phi_out[md]["phi"][wv] = phi_sol
            # rhs is passed so wide-field detector patterns can be applied
            detval_out[md]["detphi"][wv] = femgetdet(phi_sol, cfg, rhs, loc, bary)

    if len(wavelengths) == 1:
        only_wv = wavelengths[0]
        for md in rfcw:
            phi_out[md]["phi"] = phi_out[md]["phi"][only_wv]
            detval_out[md]["detphi"] = detval_out[md]["detphi"][only_wv]

    if len(rfcw) == 1:
        phi_out = phi_out[rfcw[0]]["phi"]
        detval_out = detval_out[rfcw[0]]["detphi"]

    return detval_out, phi_out
88
+
89
+
90
def _femlhs_wavelength_params(cfg: dict, wavelength: str):
    """Return ``(prop, reff, omega)`` for ``wavelength`` from a possibly per-wavelength cfg."""
    if isinstance(cfg.get("prop"), dict) and wavelength:
        props = cfg["prop"][wavelength]
        if isinstance(cfg.get("reff"), dict):
            reff = cfg["reff"][wavelength]
        else:
            reff = cfg["reff"]
        if isinstance(cfg.get("omega"), dict):
            omega = cfg["omega"].get(wavelength, 0)
        else:
            omega = cfg.get("omega", 0)
        return props, reff, omega
    return cfg["prop"], cfg.get("reff", 0.493), cfg.get("omega", 0)


def _femlhs_optical_coeffs(props, seg, nn: int, ne: int):
    """Return ``(mua, musp, nref)`` from node/element-wise or label-based ``props``.

    ``props`` columns are [mua, mus, g, n]: musp = mus * (1 - g) when at least
    3 columns are present (column 1 is used directly otherwise) and nref
    defaults to 1.37 when there is no 4th column.
    """
    if props.shape[0] == nn or props.shape[0] == ne:
        rows = slice(None)
    elif seg is not None:
        # Label-based: each node/element picks the row of its label
        # (out-of-range labels are clipped to the table bounds).
        rows = np.clip(seg.astype(np.int32), 0, props.shape[0] - 1)
    else:
        raise ValueError("Property format not recognized")

    table = props[rows]
    ncol = props.shape[1]
    mua = table[:, 0]
    musp = table[:, 1] * (1 - table[:, 2]) if ncol >= 3 else table[:, 1]
    # Per-node/element nref, including for labels: each region keeps its own
    # refractive index instead of inheriting the first label's value.
    nref = table[:, 3] if ncol >= 4 else 1.37
    return mua, musp, nref


def femlhs(
    cfg: dict, deldotdel_mat: np.ndarray, wavelength: str = "", mode: int = 1
) -> sparse.csr_matrix:
    """
    Assemble the FEM left-hand-side (system) matrix for one wavelength/mode.

    The matrix is assembled in COO form (duplicate entries are summed) and
    returned as CSR. It is the sum of

    * a diffusion term: ``deldotdel_mat`` scaled by D = 1 / (3 (mua + musp));
    * an absorption mass term: V/20 off-diagonal and V/10 diagonal for
      element-wise mua, or the exact nodal weights ``w1``/``w2`` for
      node-wise mua (D is then averaged over each element's nodes);
    * an imaginary ``omega * nref / C0`` mass term when ``omega > 0``;
    * a Robin boundary term on ``cfg["face"]``: the triangle mass matrix
      (area/6 diagonal, area/12 off-diagonal) times (1 - Reff) / (2 (1 + Reff)).

    Parameters
    ----------
    cfg : dict
        Uses ``node``, ``elem`` (1-based, first 4 columns), ``face`` (1-based,
        first 3 columns used), ``evol`` (per element), ``area`` (per face),
        ``prop`` and optionally ``seg``, ``reff`` (default 0.493) and
        ``omega`` (default 0). When ``prop`` is a dict keyed by wavelength,
        ``reff``/``omega`` may also be such dicts.
    deldotdel_mat : ndarray
        Per-element matrix with at least 10 columns holding the symmetric 4x4
        element stiffness entries packed as [00, 01, 02, 03, 11, 12, 13, 22,
        23, 33]. It is not multiplied by ``evol`` here, so it must already
        include the element volume.
    wavelength : str
        Key into per-wavelength dicts; ignored when ``prop`` is not a dict.
    mode : int
        When 2, ``omega`` is forced to 0 (no frequency-domain term).

    Returns
    -------
    scipy.sparse.csr_matrix
        (nn x nn) matrix; complex when ``omega > 0``, float otherwise.

    Raises
    ------
    ValueError
        If ``prop`` has neither nn nor ne rows and no ``seg`` is given.
    """
    nn = cfg["node"].shape[0]
    ne = cfg["elem"].shape[0]
    evol = cfg["evol"]
    area = cfg["area"]

    # Convert 1-based mesh indices to 0-based for numpy indexing
    elem_0 = cfg["elem"][:, :4].astype(np.int32) - 1
    face_0 = cfg["face"].astype(np.int32) - 1

    props, reff, omega = _femlhs_wavelength_params(cfg, wavelength)
    if mode == 2:
        omega = 0
    use_omega = omega > 0

    mua, musp, nref = _femlhs_optical_coeffs(props, cfg.get("seg", None), nn, ne)
    dcoeff = 1.0 / (3.0 * (mua + musp))

    # (i, j) local node pairs and their packed column in deldotdel_mat
    pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
    offdiag_idx = (1, 2, 3, 5, 6, 8)
    diag_idx = (0, 4, 7, 9)

    if len(mua) == ne:
        # Element-wise coefficients: linear-tet mass matrix is V/20 (i != j)
        # and V/10 (i == j).
        def entry(col, k, diagonal):
            mass_w = 0.10 if diagonal else 0.05
            val = deldotdel_mat[:, col] * dcoeff + mass_w * mua * evol
            if use_omega:
                val = val.astype(complex) + 1j * mass_w * omega * R_C0 * nref * evol
            return val

    else:
        # Node-wise coefficients: integral of phi_i*phi_j*phi_k over a linear
        # tet is V/120 * [2 if k in (i, j) else 1] for i != j, and
        # V/60 * [3 if k == i else 1] for i == j.
        w1 = (1 / 120) * np.array(
            [
                [2, 2, 1, 1],
                [2, 1, 2, 1],
                [2, 1, 1, 2],
                [1, 2, 2, 1],
                [1, 2, 1, 2],
                [1, 1, 2, 2],
            ]
        ).T
        w2 = (1 / 60) * (np.diag([2, 2, 2, 2]) + 1)

        mua_e = mua[elem_0]
        dcoeff_e = np.mean(dcoeff[elem_0], axis=1)
        nref_e = nref[elem_0] if hasattr(nref, "__len__") and len(nref) == nn else nref

        def entry(col, k, diagonal):
            w = w2[:, k] if diagonal else w1[:, k]
            val = deldotdel_mat[:, col] * dcoeff_e + (mua_e @ w) * evol
            if use_omega:
                if hasattr(nref_e, "__len__"):
                    n_term = nref_e @ w
                else:
                    n_term = nref_e * (0.10 if diagonal else 0.05)
                val = val.astype(complex) + 1j * omega * R_C0 * n_term * evol
            return val

    rows_list, cols_list, vals_list = [], [], []

    for k, ((i, j), col) in enumerate(zip(pairs, offdiag_idx)):
        val = entry(col, k, diagonal=False)
        # Symmetric: the same value goes to (i, j) and (j, i)
        rows_list += [elem_0[:, i], elem_0[:, j]]
        cols_list += [elem_0[:, j], elem_0[:, i]]
        vals_list += [val, val]

    for k, col in enumerate(diag_idx):
        rows_list.append(elem_0[:, k])
        cols_list.append(elem_0[:, k])
        vals_list.append(entry(col, k, diagonal=True))

    # Robin boundary condition on the surface triangles
    bc_coeff = (1 - reff) / (12.0 * (1 + reff))
    adiag_bc = area * bc_coeff
    aoffd_bc = adiag_bc * 0.5

    for i, j in ((0, 1), (0, 2), (1, 2)):
        rows_list += [face_0[:, i], face_0[:, j]]
        cols_list += [face_0[:, j], face_0[:, i]]
        vals_list += [aoffd_bc, aoffd_bc]

    for k in range(3):
        rows_list.append(face_0[:, k])
        cols_list.append(face_0[:, k])
        vals_list.append(adiag_bc)

    rows = np.concatenate(rows_list)
    cols = np.concatenate(cols_list)
    vals = np.concatenate(vals_list)

    dtype = complex if use_omega else float
    return sparse.coo_matrix((vals, (rows, cols)), shape=(nn, nn), dtype=dtype).tocsr()
257
+
258
+
259
def _count_point_optodes(pts) -> int:
    """Number of point optodes in ``pts``; None/empty gives 0 and a flat vector counts as one."""
    if pts is None:
        return 0
    pts = np.asarray(pts)
    return np.atleast_2d(pts).shape[0] if pts.size > 0 else 0


def femrhs(
    cfg: dict, sd: dict = None, wv: str = "", md: int = 1
) -> Tuple[sparse.spmatrix, np.ndarray, np.ndarray, np.ndarray]:
    """
    Create right-hand-side vectors for the FEM system.

    Parameters
    ----------
    cfg : dict
        Uses ``node``, ``elem`` (1-based, first 4 columns) and whatever
        ``getoptodes`` reads.
    sd : dict, optional
        Not used by this function; kept for interface compatibility.
    wv : str
        Wavelength key forwarded to ``getoptodes``.
    md : int
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    rhs : scipy.sparse.csr_matrix (nn x Ncols)
        RHS vectors. Column order: [point_src, wide_src, point_det, wide_det].
        A point-optode column holds its barycentric weights on the 4 nodes of
        the enclosing element; optodes not located in the mesh (NaN from
        ``tsearchn``) leave an all-zero column. Wide-field columns are the
        ``(nn x Npattern)`` patterns copied verbatim.
    loc : ndarray (Ncols,)
        Enclosing element ID per column (1-based); NaN for wide-field
        columns and unlocated point optodes.
    bary : ndarray (Ncols x 4)
        Barycentric coordinates for point optodes; NaN for wide-field ones.
    optode : ndarray
        Point source positions stacked above point detector positions.
    """
    import iso2mesh as i2m

    optsrc, optdet, widesrc, widedet = getoptodes(cfg, wv)

    # Count point optodes *after* promotion to 2-D so that a single optode
    # given as a flat [x, y, z] vector counts as one (not as three).
    srcnum = _count_point_optodes(optsrc)
    detnum = _count_point_optodes(optdet)

    # widesrc/widedet are stored as (Nn x Npattern): one column per pattern
    wfsrcnum = widesrc.shape[1] if widesrc is not None and widesrc.size > 0 else 0
    wfdetnum = widedet.shape[1] if widedet is not None and widedet.size > 0 else 0

    nn = cfg["node"].shape[0]
    total_cols = srcnum + wfsrcnum + detnum + wfdetnum

    if total_cols == 0:
        return (
            sparse.csr_matrix((nn, 0)),
            np.array([]),
            np.array([]).reshape(0, 4),
            np.array([]),
        )

    rhs = sparse.lil_matrix((nn, total_cols))
    # One loc/bary row per RHS column; wide-field entries stay NaN
    loc = np.full(total_cols, np.nan)
    bary = np.full((total_cols, 4), np.nan)

    # elem is 1-based: tsearchn expects and returns 1-based element IDs
    elem = cfg["elem"][:, :4].astype(np.int32)
    elem_0 = elem - 1

    def place_points(pts, col0):
        """Locate point optodes and write their barycentric weights from column col0."""
        pts = np.atleast_2d(pts)
        locs, barys = i2m.tsearchn(cfg["node"], elem, pts[:, :3])
        for i in range(pts.shape[0]):
            if not np.isnan(locs[i]):
                eid = int(locs[i]) - 1  # 1-based element ID -> 0-based row
                rhs[elem_0[eid, :], col0 + i] = barys[i, :]
        loc[col0 : col0 + pts.shape[0]] = locs
        bary[col0 : col0 + pts.shape[0], :] = barys
        return pts

    col = 0
    if srcnum > 0:
        optsrc = place_points(optsrc, col)
    col += srcnum

    if wfsrcnum > 0:
        rhs[:, col : col + wfsrcnum] = widesrc
    col += wfsrcnum

    if detnum > 0:
        optdet = place_points(optdet, col)
    col += detnum

    if wfdetnum > 0:
        rhs[:, col : col + wfdetnum] = widedet

    point_sets = [pts for pts, num in ((optsrc, srcnum), (optdet, detnum)) if num > 0]
    optode = np.vstack(point_sets) if point_sets else np.array([])

    return rhs.tocsr(), loc, bary, optode
363
+
364
+
365
def femgetdet(
    phi: np.ndarray,
    cfg: dict,
    rhs: np.ndarray,
    loc: np.ndarray = None,
    bary: np.ndarray = None,
) -> np.ndarray:
    """
    Project the forward solution onto the detector RHS columns.

    The RHS/phi column layout is [point sources, wide sources, point
    detectors, wide detectors]. Detector readings are obtained adjoint-style
    as ``rhs_det.T @ phi_src``.

    Parameters
    ----------
    phi : ndarray
        Forward solution, one column per RHS column (nn x total_cols).
    cfg : dict
        Source/detector counts are read from ``srcpos``/``detpos`` (one
        point per row) and ``widesrc``/``widedet`` (one pattern per column).
    rhs : ndarray or sparse matrix
        RHS matrix from ``femrhs`` (nn x total_cols).
    loc, bary : ndarray, optional
        Accepted for interface compatibility; not used here.

    Returns
    -------
    ndarray
        (n_det x n_src) detector values, or an empty array when there are
        no sources or no detectors.
    """

    def point_count(key):
        # Rows of a point-optode array; a flat vector counts as one row.
        value = cfg.get(key)
        if value is None:
            return 0
        as_rows = np.atleast_2d(value)
        return as_rows.shape[0] if as_rows.size > 0 else 0

    def pattern_count(key):
        # Columns of a wide-field pattern matrix stored as (Nn x Npattern).
        value = cfg.get(key)
        if value is None or value.size == 0:
            return 0
        return value.shape[1]

    n_src = point_count("srcpos") + pattern_count("widesrc")
    n_det = point_count("detpos") + pattern_count("widedet")

    if not (n_src and n_det):
        return np.array([])

    # Detector columns immediately follow all source columns
    det_block = rhs[:, n_src : n_src + n_det]
    if sparse.issparse(rhs):
        det_block = det_block.toarray()

    return det_block.T @ phi[:, :n_src]
443
+
444
+
445
import logging

try:
    from numba import njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

# Report the backend through logging instead of print(): a library should not
# write to stdout as a side effect of being imported.
logging.getLogger(__name__).debug(
    "Numba %s for Jacobian acceleration",
    "available" if HAS_NUMBA else "not available",
)

if HAS_NUMBA:

    @njit(parallel=True, cache=True)
    def _jac_core(phi, elem_0, evol, src_cols, det_cols):
        """
        Numba-accelerated per-element mua sensitivities.

        For each (src_col, det_col) pair and element, computes
        ``-(sum_i ps_i*pd_i + 0.5 * sum_{i != j} ps_i*pd_j) * 0.1 * evol``,
        i.e. minus the linear-tet mass-matrix product (V/10 diagonal, V/20
        off-diagonal) of the source and detector fields at the 4 nodes.

        The output buffer is created with ``np.zeros`` (float64), so ``phi``
        is expected to be real-valued. Pairs are processed in parallel via
        ``prange``.
        """
        nelem = elem_0.shape[0]
        nsd = len(src_cols)
        Jmua_elem = np.zeros((nsd, nelem))

        for isd in prange(nsd):
            src_col = src_cols[isd]
            det_col = det_cols[isd]

            for ie in range(nelem):
                n0, n1, n2, n3 = (
                    elem_0[ie, 0],
                    elem_0[ie, 1],
                    elem_0[ie, 2],
                    elem_0[ie, 3],
                )

                ps0, ps1, ps2, ps3 = (
                    phi[n0, src_col],
                    phi[n1, src_col],
                    phi[n2, src_col],
                    phi[n3, src_col],
                )
                pd0, pd1, pd2, pd3 = (
                    phi[n0, det_col],
                    phi[n1, det_col],
                    phi[n2, det_col],
                    phi[n3, det_col],
                )

                diag_sum = ps0 * pd0 + ps1 * pd1 + ps2 * pd2 + ps3 * pd3
                cross_sum = (
                    ps0 * pd1
                    + ps1 * pd0
                    + ps0 * pd2
                    + ps2 * pd0
                    + ps0 * pd3
                    + ps3 * pd0
                    + ps1 * pd2
                    + ps2 * pd1
                    + ps1 * pd3
                    + ps3 * pd1
                    + ps2 * pd3
                    + ps3 * pd2
                )

                Jmua_elem[isd, ie] = -(diag_sum + cross_sum * 0.5) * 0.1 * evol[ie]

        return Jmua_elem
507
+
508
+
509
def _jac_elem_numpy(phi, elem_0, evol, src_cols, det_cols):
    """Pure-NumPy per-element mua sensitivities; supports real and complex ``phi``."""
    Jmua_elem = np.zeros((len(src_cols), elem_0.shape[0]), dtype=phi.dtype)
    evol_scaled = 0.1 * evol

    for isd, (src_col, det_col) in enumerate(zip(src_cols, det_cols)):
        phi_src = phi[elem_0, src_col]  # (ne x 4) source field at element nodes
        phi_det = phi[elem_0, det_col]  # (ne x 4) detector (adjoint) field
        diag_sum = (phi_src * phi_det).sum(axis=1)
        # diag + 0.5 * sum_{i != j} ps_i*pd_j == 0.5 * (diag + sum_{i,j} ps_i*pd_j);
        # the right-hand form avoids a subtraction.
        full_sum = phi_src.sum(axis=1) * phi_det.sum(axis=1)
        Jmua_elem[isd, :] = -0.5 * (diag_sum + full_sum) * evol_scaled

    return Jmua_elem


def jac(sd, phi, deldotdel_mat, elem, evol, iselem=False):
    """
    Build the absorption (mua) Jacobian using the adjoint method.

    For each active source/detector pair, the element sensitivity is
    ``-phi_src^T M_e phi_det`` with ``M_e`` the linear-tetrahedron mass
    matrix (V/10 diagonal, V/20 off-diagonal). The Numba kernel is used when
    available and ``phi`` is real. Complex (frequency-domain) fields always
    take the NumPy path, because ``_jac_core`` accumulates into a float64
    buffer.

    Parameters
    ----------
    sd : ndarray (Npair x >=2)
        Columns 0/1 are the source/detector column indices into ``phi``. An
        optional column 2 flags active pairs; only rows equal to 1 are used.
    phi : ndarray (nn x Ncols)
        Forward solutions, indexed by the columns given in ``sd``.
    deldotdel_mat : ndarray
        Not used here; accepted for interface compatibility.
    elem : ndarray
        Element connectivity (1-based, first 4 columns used).
    evol : ndarray
        Per-element volume weights.
    iselem : bool
        Not used here; accepted for interface compatibility.

    Returns
    -------
    Jmua_node : ndarray (Npair x nn)
        Element sensitivities spread onto nodes with weight 0.25 each.
    Jmua_elem : ndarray (Npair x ne)
        Per-element sensitivities.
    """
    elem_0 = elem[:, :4].astype(np.int32) - 1
    nelem = elem_0.shape[0]
    nn = phi.shape[0]

    if sd.shape[1] >= 3:
        sd_active = sd[sd[:, 2] == 1, :2].astype(np.int32)
    else:
        sd_active = sd[:, :2].astype(np.int32)

    # NOTE(review): sd entries are used directly as 0-based phi column indices
    # (the module's 1-based convention covers mesh indices) - confirm sdmap.
    src_cols = sd_active[:, 0]
    det_cols = sd_active[:, 1]

    if HAS_NUMBA and not np.iscomplexobj(phi):
        Jmua_elem = _jac_core(
            np.ascontiguousarray(phi), elem_0, evol, src_cols, det_cols
        )
    else:
        Jmua_elem = _jac_elem_numpy(phi, elem_0, evol, src_cols, det_cols)

    # P[n, e] = 0.25 when node n belongs to element e: distribute each
    # element's sensitivity equally to its 4 nodes.
    rows = elem_0.ravel()
    cols = np.repeat(np.arange(nelem), 4)
    data = np.full(nelem * 4, 0.25)
    P = sparse.csr_matrix((data, (rows, cols)), shape=(nn, nelem))

    Jmua_node = (P @ Jmua_elem.T).T

    return Jmua_node, Jmua_elem
570
+
571
+
572
def jacchrome(Jmua: dict, chromophores: List[str]) -> dict:
    """
    Build chromophore Jacobians from per-wavelength mua Jacobians.

    For each chromophore, every wavelength's mua Jacobian is scaled by that
    chromophore's extinction coefficient at that wavelength. The blocks are
    then stacked vertically in the insertion order of ``Jmua``'s keys.

    Parameters
    ----------
    Jmua : dict
        Maps wavelength key -> mua Jacobian for that wavelength.
    chromophores : list of str
        Chromophore names passed to ``extinction``.

    Returns
    -------
    dict
        Maps chromophore name -> stacked Jacobian. With a single wavelength
        the scaled Jacobian is returned unstacked (same shape as the input).

    Raises
    ------
    ValueError
        If ``Jmua`` is not a dict.
    """
    if not isinstance(Jmua, dict):
        raise ValueError("Jmua must be a dict with wavelength keys")

    wavelengths = list(Jmua.keys())
    # extin is indexed [wavelength_index, chromophore_index]
    extin, _ = extinction(wavelengths, chromophores)

    Jchrome = {}
    for i, ch in enumerate(chromophores):
        blocks = [Jmua[wv] * extin[j, i] for j, wv in enumerate(wavelengths)]
        if not blocks:
            Jchrome[ch] = None
        elif len(blocks) == 1:
            Jchrome[ch] = blocks[0]
        else:
            # A single vstack avoids re-copying the growing array per wavelength
            Jchrome[ch] = np.vstack(blocks)

    return Jchrome