isoview 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
isoview/fusion.py ADDED
@@ -0,0 +1,979 @@
1
+ """Multi-view fusion for IsoView microscopy data."""
2
+
3
+ from pathlib import Path
4
+ import numpy as np
5
+ from typing import Literal
6
+
7
+ try:
8
+ import pywt
9
+ PYWT_AVAILABLE = True
10
+ except ImportError:
11
+ PYWT_AVAILABLE = False
12
+
13
+ from .config import ProcessingConfig
14
+ from .io import read_volume, write_volume
15
+ from .transforms import (
16
+ estimate_registration,
17
+ apply_registration,
18
+ apply_registration_to_mask,
19
+ estimate_channel_transform,
20
+ estimate_camera_transform,
21
+ apply_channel_transform,
22
+ apply_camera_transform,
23
+ )
24
+ from .intensity import estimate_intensity_correction, apply_intensity_correction, estimate_background, subtract_background
25
+ from .masks import create_fusion_mask, create_slice_mask, create_average_mask, remove_mask_anomalies
26
+
27
+
28
def _find_volume_file(tm_folder: Path, specimen: int, timepoint: int, camera: int, channel: int) -> Path | None:
    """Locate the volume file matching the given acquisition parameters.

    Supported formats are probed in priority order (.klb, .tif, .zarr);
    the first existing file wins. Returns None when no candidate exists.
    """
    stem = f"SPM{specimen:02d}_TM{timepoint:06d}_CM{camera:02d}_CHN{channel:02d}"
    candidates = (tm_folder / f"{stem}{ext}" for ext in (".klb", ".tif", ".zarr"))
    return next((candidate for candidate in candidates if candidate.exists()), None)
36
+
37
+
38
def _find_mask_file(tm_folder: Path, specimen: int, timepoint: int, camera: int, channel: int) -> Path | None:
    """Locate the segmentation-mask file matching the given acquisition parameters.

    Same lookup as _find_volume_file but with the ".segmentationMask"
    suffix in the stem. Formats are probed in priority order
    (.klb, .tif, .zarr); returns None when nothing is found.
    """
    stem = f"SPM{specimen:02d}_TM{timepoint:06d}_CM{camera:02d}_CHN{channel:02d}.segmentationMask"
    candidates = (tm_folder / f"{stem}{ext}" for ext in (".klb", ".tif", ".zarr"))
    return next((candidate for candidate in candidates if candidate.exists()), None)
46
+
47
+
48
def fuse_views(
    vol0: np.ndarray,
    vol1: np.ndarray,
    mask0: np.ndarray | None = None,
    mask1: np.ndarray | None = None,
    fusion_type: str = "adaptive_blending",
    blending_range: int = 20,
    search_x: tuple[int, int, int] = (-50, 50, 10),
    search_y: tuple[int, int, int] = (-50, 50, 10),
    mask_percentile: float = 1.0,
    mode: Literal["camera", "channel"] = "camera",
    scaling: float = 1.0,
    optimizer: Literal["gradient_descent", "nelder_mead", "bfgs"] = "gradient_descent",
) -> tuple[np.ndarray, dict]:
    """
    Fuse two 3D volumes.

    This is the core fusion function with no file I/O - operates purely on arrays.
    Supports both camera-camera and channel-channel fusion modes.

    Parameters
    ----------
    vol0 : ndarray
        Reference volume (Z, Y, X)
    vol1 : ndarray
        Moving volume (Z, Y, X)
    mask0 : ndarray or None
        Binary mask for vol0, created if None
    mask1 : ndarray or None
        Binary mask for vol1, created if None
    fusion_type : str
        Blending method: adaptive_blending, geometric, wavelet, average
    blending_range : int
        Width of transition zone in Z-planes
    search_x : tuple
        Registration search range (start, stop, step) for X (camera mode)
    search_y : tuple
        Registration search range (start, stop, step) for Y (camera mode)
    mask_percentile : float
        Percentile for mask creation if masks not provided
    mode : str
        'camera' for x/y/rotation (XY plane), 'channel' for z/rotation (XZ plane)
    scaling : float
        Axial/lateral pixel ratio for channel mode (z_spacing / xy_spacing)
    optimizer : str
        Optimization algorithm: 'gradient_descent' (MATLAB-equivalent),
        'nelder_mead', or 'bfgs'

    Returns
    -------
    fused : ndarray
        Fused volume (Z, Y, X)
    params : dict
        Estimated parameters (transform, intensity correction)

    Notes
    -----
    Camera mode: estimates x_offset, y_offset, rotation (around Z axis)
        Uses coarse grid search then fine optimization with rotation.
        Equivalent to MATLAB transformCamera.

    Channel mode: estimates z_offset, rotation (in XZ plane)
        Works on XZ slices to align light sheets from different channels.
        Equivalent to MATLAB transformChannel.
    """
    # create masks if not provided
    if mask0 is None:
        mask0 = create_fusion_mask(vol0, mask_percentile)
    if mask1 is None:
        mask1 = create_fusion_mask(vol1, mask_percentile)

    # estimate registration based on mode
    if mode == "camera":
        transform = estimate_camera_transform(
            vol0, vol1,
            search_x=search_x,
            search_y=search_y,
            optimizer=optimizer
        )
        # apply camera transform (x/y + rotation per XY slice)
        vol1_transformed = apply_camera_transform(vol1, transform)
        # BUGFIX: warp the moving view's mask (mask1), not the reference
        # mask (mask0). Previously mask0 was transformed here, leaving
        # mask1_transformed misaligned with vol1_transformed; the channel
        # branch below and _fuse_camera_pair already use mask1 correctly.
        mask1_transformed = apply_camera_transform(mask1.astype(np.float32), transform) > 0.5

    elif mode == "channel":
        transform = estimate_channel_transform(
            vol0, vol1,
            scaling=scaling,
            optimizer=optimizer
        )
        # apply channel transform (z + rotation per XZ slice)
        vol1_transformed = apply_channel_transform(vol1, transform, scaling)
        mask1_transformed = apply_channel_transform(mask1.astype(np.float32), transform, scaling) > 0.5

    else:
        raise ValueError(f"Unknown mode: {mode}. Use 'camera' or 'channel'.")

    # estimate intensity correction between the (aligned) views
    correction = estimate_intensity_correction(vol0, vol1_transformed, mask0, mask1_transformed)

    # apply intensity correction to the moving view
    vol1_corrected = apply_intensity_correction(vol1_transformed, correction)

    # compute blending masks: per-pixel transition Z-planes
    slice_mask0 = create_slice_mask(vol0, mask0)
    slice_mask1 = create_slice_mask(vol1_corrected, mask1_transformed)
    avg_mask = create_average_mask(slice_mask0, slice_mask1, mode="overlap")
    avg_mask = remove_mask_anomalies(avg_mask, vol0.shape[0], blending_range)

    # blend the two aligned, intensity-matched views
    fused = blend_views(
        [vol0, vol1_corrected],
        [avg_mask, avg_mask],
        method=fusion_type,
        blending_range=blending_range
    )

    params = {
        'transform': transform,
        'correction': correction,
        'mode': mode,
    }

    return fused, params
171
+
172
+
173
def multi_fuse(
    config: ProcessingConfig,
    estimate_params: bool = True,
    apply_fusion: bool = True,
    overwrite: bool = False
) -> None:
    """
    Multi-view fusion (matches MATLAB multiFuse).

    Orchestrates per-timepoint fusion: optional camera-pair fusion first,
    then channel fusion either per camera or on the camera-fused outputs
    ("4-view" mode). Results are written under output_dir/TM######.

    Parameters
    ----------
    config : ProcessingConfig
        Processing configuration with fusion parameters.
        - input_dir: where to read corrected volumes (TM###### folders)
        - output_dir: where to write fused volumes
    estimate_params : bool, default=True
        Estimate transformations and corrections
    apply_fusion : bool, default=True
        Apply fusion and save results
    overwrite : bool, default=False
        Overwrite existing fusion results

    Raises
    ------
    ValueError
        If neither fusion_camera_pairs nor fusion_channel_pairs is set.
    """
    if not config.fusion_camera_pairs and not config.fusion_channel_pairs:
        raise ValueError("No fusion pairs specified. Set fusion_camera_pairs or fusion_channel_pairs.")

    # determine fusion mode
    # fusion_4view_only: skip camera fusion, use existing camera-fused files
    # 'full'   = camera fusion followed by channel fusion of those outputs
    # 'channel' = channel fusion per camera only; 'camera' = camera fusion only
    if getattr(config, 'fusion_4view_only', False) and config.fusion_channel_pairs:
        mode = '4view_only'
    elif config.fusion_camera_pairs and config.fusion_channel_pairs:
        mode = 'full'
    elif config.fusion_channel_pairs:
        mode = 'channel'
    else:
        mode = 'camera'

    print(f"Fusion mode: {mode}")
    print(f"Input: {config.input_dir}")
    print(f"Output: {config.output_dir}")
    print(f"Estimate params: {estimate_params}, Apply fusion: {apply_fusion}")

    for timepoint in config.timepoints:
        print(f"\nProcessing timepoint {timepoint}")

        # read from input_dir
        input_tm_folder = config.input_dir / f"TM{timepoint:06d}"
        if not input_tm_folder.exists():
            # missing timepoints are skipped, not fatal
            print(f" Warning: TM folder not found: {input_tm_folder}, skipping")
            continue

        # write to output_dir (created lazily by the per-pair helpers)
        output_tm_folder = config.output_dir / f"TM{timepoint:06d}"

        # step 1: camera fusion (if enabled)
        if config.fusion_camera_pairs:
            for cam0, cam1 in config.fusion_camera_pairs:
                # determine channel for this camera pair
                # cameras 0,1 use channel 1; cameras 2,3 use channel 0
                # NOTE(review): this mapping is hard-coded for the standard
                # IsoView rig layout — confirm for other configurations
                channel = 1 if cam0 < 2 else 0
                print(f" Fusing cameras {cam0} + {cam1} (channel {channel})")
                _fuse_camera_pair(
                    input_tm_folder, output_tm_folder, config.specimen, timepoint,
                    cam0, cam1, channel, config, estimate_params, apply_fusion, overwrite
                )

        # step 2: channel fusion
        if config.fusion_channel_pairs:
            if mode in ('full', '4view_only'):
                # 4-view mode: fuse the camera-fused outputs (CHN00 + CHN01)
                for ch0, ch1 in config.fusion_channel_pairs:
                    print(f" Fusing camera-fused channels {ch0} + {ch1} (4-view)")
                    _fuse_camera_fused_channels(
                        input_tm_folder, output_tm_folder, config.specimen, timepoint,
                        ch0, ch1, config, estimate_params, apply_fusion, overwrite
                    )
            elif mode == 'channel':
                # channel-only mode: fuse channels per camera (original behavior)
                for ch0, ch1 in config.fusion_channel_pairs:
                    for camera in config.cameras:
                        print(f" Fusing channels {ch0} + {ch1} for camera {camera}")
                        _fuse_channel_pair(
                            input_tm_folder, output_tm_folder, config.specimen, timepoint,
                            camera, ch0, ch1, config, estimate_params, apply_fusion, overwrite
                        )

    print("\nFusion complete!")
259
+
260
+
261
def _fuse_camera_pair(
    input_tm_folder: Path,
    output_tm_folder: Path,
    specimen: int,
    timepoint: int,
    cam0: int,
    cam1: int,
    channel: int,
    config: ProcessingConfig,
    estimate_params: bool,
    apply_fusion: bool,
    overwrite: bool
):
    """Fuse two cameras for a single timepoint.

    Pipeline: load volumes and masks -> estimate camera transform
    (x/y offset + rotation) -> warp moving view -> subtract background
    -> intensity-match -> build per-pixel transition-plane mask ->
    blend -> write result to output_tm_folder.

    cam0 is the fixed/reference camera; cam1 is warped onto it.
    Skips silently (with a console warning) when inputs are missing or
    the output already exists and overwrite is False.
    """
    # output path
    output_tm_folder.mkdir(parents=True, exist_ok=True)
    fused_file = output_tm_folder / f"SPM{specimen:02d}_TM{timepoint:06d}_CM{cam0:02d}_CM{cam1:02d}_CHN{channel:02d}.{config.output_format}"

    if fused_file.exists() and not overwrite:
        print(f" Fusion exists: {fused_file.name}, skipping")
        return

    # find input volumes
    vol0_path = _find_volume_file(input_tm_folder, specimen, timepoint, cam0, channel)
    vol1_path = _find_volume_file(input_tm_folder, specimen, timepoint, cam1, channel)

    if vol0_path is None:
        print(f" Warning: Camera {cam0} volume not found, skipping")
        return
    if vol1_path is None:
        print(f" Warning: Camera {cam1} volume not found, skipping")
        return

    # load volumes
    print(f" Loading {vol0_path.name}")
    vol0 = read_volume(vol0_path)
    print(f" Loading {vol1_path.name}")
    vol1 = read_volume(vol1_path)

    # load or create masks (pre-computed segmentation masks are preferred;
    # otherwise derive a percentile-threshold mask from the volume itself)
    mask0_path = _find_mask_file(input_tm_folder, specimen, timepoint, cam0, channel)
    mask1_path = _find_mask_file(input_tm_folder, specimen, timepoint, cam1, channel)

    if mask0_path and mask0_path.exists():
        print(f" Loading mask {mask0_path.name}")
        mask0 = read_volume(mask0_path) > 0
        print(f" mask0: {mask0.sum()}/{mask0.size} pixels ({100*mask0.sum()/mask0.size:.1f}%)")
    else:
        print(f" Creating mask for camera {cam0}")
        mask0 = create_fusion_mask(vol0, config.mask_percentile)
        print(f" mask0: {mask0.sum()}/{mask0.size} pixels ({100*mask0.sum()/mask0.size:.1f}%)")

    if mask1_path and mask1_path.exists():
        print(f" Loading mask {mask1_path.name}")
        mask1 = read_volume(mask1_path) > 0
        print(f" mask1: {mask1.sum()}/{mask1.size} pixels ({100*mask1.sum()/mask1.size:.1f}%)")
    else:
        print(f" Creating mask for camera {cam1}")
        mask1 = create_fusion_mask(vol1, config.mask_percentile)
        print(f" mask1: {mask1.sum()}/{mask1.size} pixels ({100*mask1.sum()/mask1.size:.1f}%)")

    # estimate registration using camera mode (x/y + rotation)
    if estimate_params:
        print(f" Estimating transformation cam{cam0} -> cam{cam1}")
        transform = estimate_camera_transform(
            vol0, vol1,
            search_x=config.fusion_search_offsets_x,
            search_y=config.fusion_search_offsets_y,
            optimizer=getattr(config, 'fusion_optimizer', 'gradient_descent')
        )

        print(f" Transform: offset=({transform['x_offset']:.2f}, {transform['y_offset']:.2f}), rot={transform['rotation']:.3f}°, corr={transform['correlation']:.3f}")
    else:
        # would need to load from metadata file - for now just estimate
        transform = estimate_camera_transform(vol0, vol1)

    if not apply_fusion:
        # estimate-only run: nothing is written
        return

    print(f" Applying fusion")

    # apply camera transforms (x/y + rotation per XY slice);
    # the mask is warped as float then re-binarized at 0.5
    vol1_transformed = apply_camera_transform(vol1, transform)
    mask1_transformed = apply_camera_transform(mask1.astype(np.float32), transform) > 0.5

    # subtract background from both volumes (MATLAB equivalent)
    background = estimate_background(vol0, vol1_transformed, config.background_percentile)
    vol0_sub = subtract_background(vol0, background)
    vol1_sub = subtract_background(vol1_transformed, background)
    print(f" Background: {background:.1f}")

    # intensity correction (applied to moving volume)
    correction = estimate_intensity_correction(vol0_sub, vol1_sub, mask0, mask1_transformed)
    print(f" Intensity: factor={correction['factor']:.3f}")
    vol1_corrected = apply_intensity_correction(vol1_sub, correction)

    # compute blending masks (Y, X) containing Z-values
    slice_mask0 = create_slice_mask(vol0_sub, mask0)
    slice_mask1 = create_slice_mask(vol1_corrected, mask1_transformed)

    # debug: check slice mask stats
    sm0_valid = (slice_mask0 > 0).sum()
    sm1_valid = (slice_mask1 > 0).sum()
    print(f" slice_mask0: {sm0_valid}/{slice_mask0.size} valid", end="")
    if sm0_valid > 0:
        print(f", range [{slice_mask0[slice_mask0>0].min()}-{slice_mask0[slice_mask0>0].max()}]")
    else:
        print()
    print(f" slice_mask1: {sm1_valid}/{slice_mask1.size} valid", end="")
    if sm1_valid > 0:
        print(f", range [{slice_mask1[slice_mask1>0].min()}-{slice_mask1[slice_mask1>0].max()}]")
    else:
        print()

    avg_mask = create_average_mask(slice_mask0, slice_mask1, mode="union")

    # check if mask is valid before removing anomalies
    valid_before = (avg_mask > 0).sum()
    z_size = vol0.shape[0]
    # NOTE(review): index [1] here vs [0] in _fuse_channel_pair —
    # presumably fusion_blending_range holds (channel, camera) ranges;
    # confirm against ProcessingConfig
    blending = config.fusion_blending_range[1]

    # automatically reduce blending if volume is too small
    # (the two ramps around the transition plane must fit within Z)
    max_blending = (z_size - 1) // 2
    effective_blending = min(blending, max_blending)
    if effective_blending < blending:
        print(f" Note: reduced blending range from {blending} to {effective_blending} (z_size={z_size})")
    blending = effective_blending

    if valid_before > 0:
        avg_mask = remove_mask_anomalies(avg_mask, z_size, blending)

    # if mask is empty, use geometric center as fallback
    valid_after = (avg_mask > 0).sum()
    if valid_after == 0:
        center_z = z_size // 2
        print(f" Mask empty after anomaly removal, using geometric center (Z={center_z})")
        avg_mask = np.full((vol0.shape[1], vol0.shape[2]), center_z, dtype=np.uint16)

    # print mask stats for debugging
    valid_mask = avg_mask > 0
    print(f" Final mask: {valid_mask.sum()}/{avg_mask.size} valid, range [{avg_mask[valid_mask].min()}-{avg_mask[valid_mask].max()}]")

    # blend with effective blending range
    fused = blend_views(
        [vol0_sub, vol1_corrected],
        [avg_mask, avg_mask],
        method=config.fusion_type,
        blending_range=blending,  # use effective blending
        front_flag=config.fusion_front_flag
    )

    # save
    write_volume(fused, fused_file)
    print(f" Saved to {fused_file}")
415
+
416
+
417
def _fuse_channel_pair(
    input_tm_folder: Path,
    output_tm_folder: Path,
    specimen: int,
    timepoint: int,
    camera: int,
    ch0: int,
    ch1: int,
    config: ProcessingConfig,
    estimate_params: bool,
    apply_fusion: bool,
    overwrite: bool
):
    """Fuse two channels for a single timepoint.

    Same overall pipeline as _fuse_camera_pair but uses the channel
    transform (z-offset + rotation in the XZ plane) and no background
    subtraction. ch0 is the fixed/reference channel; ch1 is warped onto
    it. Skips (with a console warning) when inputs are missing or the
    output exists and overwrite is False.
    """
    # output path
    output_tm_folder.mkdir(parents=True, exist_ok=True)
    fused_file = output_tm_folder / f"SPM{specimen:02d}_TM{timepoint:06d}_CM{camera:02d}_CHN{ch0:02d}_CHN{ch1:02d}.{config.output_format}"

    if fused_file.exists() and not overwrite:
        print(f" Fusion exists: {fused_file.name}, skipping")
        return

    # find input volumes
    vol0_path = _find_volume_file(input_tm_folder, specimen, timepoint, camera, ch0)
    vol1_path = _find_volume_file(input_tm_folder, specimen, timepoint, camera, ch1)

    if vol0_path is None:
        print(f" Warning: Channel {ch0} volume not found, skipping")
        return
    if vol1_path is None:
        print(f" Warning: Channel {ch1} volume not found, skipping")
        return

    # load volumes
    vol0 = read_volume(vol0_path)
    vol1 = read_volume(vol1_path)

    # load or create masks (pre-computed segmentation masks preferred)
    mask0_path = _find_mask_file(input_tm_folder, specimen, timepoint, camera, ch0)
    mask1_path = _find_mask_file(input_tm_folder, specimen, timepoint, camera, ch1)

    if mask0_path and mask0_path.exists():
        mask0 = read_volume(mask0_path) > 0
    else:
        mask0 = create_fusion_mask(vol0, config.mask_percentile)

    if mask1_path and mask1_path.exists():
        mask1 = read_volume(mask1_path) > 0
    else:
        mask1 = create_fusion_mask(vol1, config.mask_percentile)

    # estimate registration using channel mode (z-offset + rotation in XZ plane)
    if estimate_params:
        print(f" Estimating transformation ch{ch0} -> ch{ch1}")
        # axial_scaling: z-spacing / xy-spacing ratio used by the XZ transform
        scaling = getattr(config, 'axial_scaling', 1.0)
        transform = estimate_channel_transform(
            vol0, vol1,
            scaling=scaling,
            optimizer=getattr(config, 'fusion_optimizer', 'gradient_descent')
        )

        print(f" Transform: z_offset={transform['z_offset']:.2f}, rot={transform['rotation']:.3f}°, corr={transform['correlation']:.3f}")
    else:
        # no stored-parameter path yet: fall back to a default estimate
        transform = estimate_channel_transform(vol0, vol1)

    if not apply_fusion:
        # estimate-only run: nothing is written
        return

    print(f" Applying fusion")

    # apply channel transforms (z + rotation per XZ slice);
    # the mask is warped as float then re-binarized at 0.5
    scaling = getattr(config, 'axial_scaling', 1.0)
    vol1_transformed = apply_channel_transform(vol1, transform, scaling)
    mask1_transformed = apply_channel_transform(mask1.astype(np.float32), transform, scaling) > 0.5

    # intensity correction (applied to moving volume)
    correction = estimate_intensity_correction(vol0, vol1_transformed, mask0, mask1_transformed)
    print(f" Intensity: factor={correction['factor']:.3f}")
    vol1_corrected = apply_intensity_correction(vol1_transformed, correction)

    # compute blending masks - channel fusion also blends along Z (same as camera)
    slice_mask0 = create_slice_mask(vol0, mask0)
    slice_mask1 = create_slice_mask(vol1_corrected, mask1_transformed)
    avg_mask = create_average_mask(slice_mask0, slice_mask1, mode="union")

    # remove anomalies based on Z dimension
    z_size = vol0.shape[0]
    # NOTE(review): index [0] here vs [1] in _fuse_camera_pair —
    # presumably fusion_blending_range holds (channel, camera) ranges;
    # confirm against ProcessingConfig
    blending = config.fusion_blending_range[0]
    # clamp blending so both ramps around the transition plane fit in Z
    max_blending = (z_size - 1) // 2
    effective_blending = min(blending, max_blending)
    if effective_blending < blending:
        print(f" Note: reduced blending range from {blending} to {effective_blending} (z_size={z_size})")
    blending = effective_blending

    avg_mask = remove_mask_anomalies(avg_mask, z_size, blending)

    # if mask is empty, use geometric center as fallback
    valid_after = (avg_mask > 0).sum()
    if valid_after == 0:
        center_z = z_size // 2
        print(f" Mask empty, using geometric center (Z={center_z})")
        avg_mask = np.full((vol0.shape[1], vol0.shape[2]), center_z, dtype=np.uint16)

    fused = blend_views(
        [vol0, vol1_corrected],
        [avg_mask, avg_mask],
        method=config.fusion_type,
        blending_range=blending,
        front_flag=getattr(config, 'fusion_front_flag', 1)
    )

    write_volume(fused, fused_file)
    print(f" Saved to {fused_file}")
530
+
531
+
532
+ def blend_views(
533
+ views: list[np.ndarray],
534
+ masks: list[np.ndarray],
535
+ method: Literal["adaptive_blending", "geometric", "wavelet", "average"] = "adaptive_blending",
536
+ blending_range: int = 20,
537
+ transition_plane: int | None = None,
538
+ wavelet: str = "db4",
539
+ wavelet_level: int = 5,
540
+ front_flag: int = 1,
541
+ **kwargs
542
+ ) -> np.ndarray:
543
+ """
544
+ Blend multiple views into single fused volume.
545
+
546
+ Parameters
547
+ ----------
548
+ views : list of ndarray
549
+ Volumes to fuse, all same shape (Z, Y, X)
550
+ masks : list of ndarray
551
+ Binary masks or transition planes (Y, X) with Z-values
552
+ method : str, default='adaptive_blending'
553
+ Fusion method: 'adaptive_blending', 'geometric', 'wavelet', 'average'
554
+ blending_range : int, default=20
555
+ Transition zone width in pixels
556
+ transition_plane : int or None, default=None
557
+ For geometric method, Z-index of transition, None uses center
558
+ wavelet : str, default='db4'
559
+ Wavelet basis for wavelet method
560
+ wavelet_level : int, default=5
561
+ Decomposition level for wavelet method
562
+ front_flag : int, default=1
563
+ Which view is dominant in front (low Z). 1=view1, 2=view2
564
+
565
+ Returns
566
+ -------
567
+ ndarray
568
+ Fused volume (Z, Y, X)
569
+ """
570
+ if len(views) != 2:
571
+ raise ValueError("Currently only 2-view fusion is supported")
572
+
573
+ if method == "adaptive_blending":
574
+ return adaptive_blending(views[0], views[1], masks[0], blending_range, front_flag)
575
+ elif method == "geometric":
576
+ return geometric_blending(views[0], views[1], transition_plane, blending_range)
577
+ elif method == "wavelet":
578
+ if not PYWT_AVAILABLE:
579
+ raise ImportError("PyWavelets (pywt) required for wavelet fusion")
580
+ return wavelet_fusion(views[0], views[1], wavelet, wavelet_level)
581
+ elif method == "average":
582
+ return average_fusion(views[0], views[1])
583
+ else:
584
+ raise ValueError(f"Unknown fusion method: {method}")
585
+
586
+
587
def adaptive_blending(view1, view2, mask, blending_range=20, front_flag=1):
    """
    Adaptive blending with smooth linear transition along Z axis.

    Parameters
    ----------
    view1 : ndarray
        Reference view (Z, Y, X)
    view2 : ndarray
        Transformed view (Z, Y, X)
    mask : ndarray
        Transition plane topology (Y, X) containing Z-indices
    blending_range : int, default=20
        Width of blending zone in Z-planes
    front_flag : int, default=1
        Which view is dominant in front (low Z). 1=view1, 2=view2

    Returns
    -------
    ndarray
        Fused volume (Z, Y, X)

    Notes
    -----
    MATLAB convention: mask[y,x] contains the Z-plane where transition occurs.
    front_flag controls which view dominates in the "front" (low Z values).

    The blend weight of each voxel depends only on its signed distance
    d = z - mask[y, x] from the transition plane:

        d <= -blending_range           -> pure front view
        -blending_range < d <= 0       -> front view fades 1.0 -> 0.5
        0 < d <= blending_range        -> front view fades 0.5 -> 0.0
        d > blending_range             -> pure back view

    so the whole volume is blended with a single vectorized table lookup
    instead of a Python loop over every (y, x) pixel. Pixels with an
    invalid transition (mask <= 0 or >= Z) get the elementwise max of
    both views to preserve signal.
    """
    Z, Y, X = view1.shape

    # swap so that "view1" is always the front (low-Z dominant) view
    if front_flag == 2:
        view1, view2 = view2, view1
    out_dtype = view1.dtype

    # per-pixel transition plane; <= 0 or out of range marks an invalid pixel
    z0 = np.asarray(mask).astype(np.int32)
    invalid = (z0 <= 0) | (z0 >= Z)
    n_invalid = int(invalid.sum())
    n_valid = z0.size - n_invalid

    # weight profiles over the blending zone (same values as the MATLAB code)
    # front: view1 fades from 1.0 to 0.5, view2 fades from 0.0 to 0.5
    w1_front = np.linspace(1.0, 0.5, blending_range, dtype=np.float32)
    w2_front = np.linspace(0.0, 0.5, blending_range, dtype=np.float32)
    # back: view1 fades from 0.5 to 0.0, view2 fades from 0.5 to 1.0
    w1_back = np.linspace(0.5, 0.0, blending_range, dtype=np.float32)
    w2_back = np.linspace(0.5, 1.0, blending_range, dtype=np.float32)

    # lookup tables indexed by clip(d, -br, br+1) + br:
    #   index 0       -> pure view1 (d <= -br)
    #   1 .. br       -> front ramp (d = -br+1 .. 0)
    #   br+1 .. 2br   -> back ramp  (d = 1 .. br)
    #   index 2br+1   -> pure view2 (d > br)
    w1_table = np.concatenate(([1.0], w1_front, w1_back, [0.0])).astype(np.float32)
    w2_table = np.concatenate(([0.0], w2_front, w2_back, [1.0])).astype(np.float32)

    # signed distance of every voxel from its pixel's transition plane
    d = np.arange(Z, dtype=np.int32)[:, None, None] - z0[None, :, :]
    idx = np.clip(d, -blending_range, blending_range + 1) + blending_range

    v1 = view1.astype(np.float32, copy=False)
    v2 = view2.astype(np.float32, copy=False)
    fused = v1 * w1_table[idx] + v2 * w2_table[idx]

    # invalid pixels: no transition known - use max of both views
    # (not average, to preserve signal)
    if n_invalid:
        np.copyto(fused, np.maximum(v1, v2), where=invalid[None, :, :])

    # print mask statistics for debugging
    if n_invalid > n_valid:
        print(f" Warning: {n_invalid}/{n_valid+n_invalid} pixels have invalid mask (using max)")

    return fused.astype(out_dtype)
682
+
683
+
684
def geometric_blending(view1, view2, transition_plane=None, blending_range=20, front_flag=1):
    """
    Geometric blending with fixed transition plane.

    Convenience wrapper around adaptive_blending that uses the same
    transition Z-plane for every (y, x) pixel.

    Parameters
    ----------
    view1 : ndarray
        Reference view (Z, Y, X)
    view2 : ndarray
        Transformed view (Z, Y, X)
    transition_plane : int or None, default=None
        Z-index of transition, None uses center
    blending_range : int, default=20
        Width of blending zone in z-planes
    front_flag : int, default=1
        Which view is dominant in front (low Z). 1=view1, 2=view2.
        Added for parity with adaptive_blending; previously the flag
        was silently fixed to 1 even when callers requested otherwise.

    Returns
    -------
    ndarray
        Fused volume (Z, Y, X)
    """
    Z, Y, X = view1.shape
    if transition_plane is None:
        # default: hand over halfway through the stack
        transition_plane = Z // 2
    # constant transition topology: every pixel switches at the same plane
    mask = np.full((Y, X), transition_plane, dtype=np.int32)
    return adaptive_blending(view1, view2, mask, blending_range, front_flag)
709
+
710
+
711
def wavelet_fusion(view1, view2, wavelet="db4", level=5):
    """
    Wavelet-based fusion using frequency decomposition.

    Each XY slice is decomposed with a 2D discrete wavelet transform.
    Approximation (low-frequency) coefficients are averaged to preserve
    overall intensity; detail (high-frequency) coefficients are selected
    by maximum magnitude so the sharper structure from either view is
    kept regardless of sign.

    Parameters
    ----------
    view1 : ndarray
        Reference view (Z, Y, X)
    view2 : ndarray
        Transformed view (Z, Y, X)
    wavelet : str, default='db4'
        Wavelet basis (Daubechies-4)
    level : int, default=5
        Decomposition level

    Returns
    -------
    ndarray
        Fused volume (Z, Y, X)

    Raises
    ------
    ImportError
        If PyWavelets is not installed.
    """
    if not PYWT_AVAILABLE:
        raise ImportError("PyWavelets required for wavelet fusion")

    Z, Y, X = view1.shape
    fused = np.zeros_like(view1, dtype=np.float32)

    for z in range(Z):
        slice1 = view1[z, :, :].astype(np.float32)
        slice2 = view2[z, :, :].astype(np.float32)

        coeffs1 = pywt.wavedec2(slice1, wavelet, level=level)
        coeffs2 = pywt.wavedec2(slice2, wavelet, level=level)

        fused_coeffs = []
        for i, (c1, c2) in enumerate(zip(coeffs1, coeffs2)):
            if i == 0:
                # approximation band: average preserves mean intensity
                fused_coeffs.append((c1 + c2) / 2)
            else:
                # detail bands: pick the coefficient with larger magnitude.
                # BUGFIX: np.maximum compared signed values, biasing toward
                # positive coefficients and discarding strong negative
                # (dark-edge) responses; max-modulus selection is the
                # standard wavelet-fusion rule.
                fused_coeffs.append(tuple(
                    np.where(np.abs(c1[j]) >= np.abs(c2[j]), c1[j], c2[j])
                    for j in range(3)
                ))

        rec = pywt.waverec2(fused_coeffs, wavelet)
        # BUGFIX: waverec2 can return an array one pixel larger per odd
        # dimension; crop to the original slice shape to avoid a
        # broadcast error on assignment for odd-sized volumes.
        fused[z, :, :] = rec[:Y, :X]

    return fused.astype(view1.dtype)
754
+
755
+
756
def average_fusion(view1, view2):
    """
    Simple arithmetic average of two views.

    Parameters
    ----------
    view1 : ndarray
        Reference view (Z, Y, X)
    view2 : ndarray
        Transformed view (Z, Y, X)

    Returns
    -------
    ndarray
        Fused volume (Z, Y, X), fused = (view1 + view2) / 2, computed in
        float32 and cast back to view1's dtype.
    """
    # accumulate in float32 so integer inputs do not overflow or truncate early
    mean = (view1.astype(np.float32) + view2.astype(np.float32)) / 2
    return mean.astype(view1.dtype)
773
+
774
+
775
def _find_camera_fused_file(tm_folder: Path, specimen: int, timepoint: int, cam0: int, cam1: int, channel: int) -> Path | None:
    """Locate a camera-fused volume file (CM##_CM## naming).

    Supported formats are probed in priority order (.klb, .tif, .zarr);
    the first existing file wins. Returns None when nothing is found.
    """
    stem = f"SPM{specimen:02d}_TM{timepoint:06d}_CM{cam0:02d}_CM{cam1:02d}_CHN{channel:02d}"
    candidates = (tm_folder / f"{stem}{ext}" for ext in (".klb", ".tif", ".zarr"))
    return next((candidate for candidate in candidates if candidate.exists()), None)
783
+
784
+
785
+ def _fuse_camera_fused_channels(
786
+ input_tm_folder: Path,
787
+ output_tm_folder: Path,
788
+ specimen: int,
789
+ timepoint: int,
790
+ ch0: int,
791
+ ch1: int,
792
+ config: ProcessingConfig,
793
+ estimate_params: bool,
794
+ apply_fusion: bool,
795
+ overwrite: bool
796
+ ):
797
+ """
798
+ Fuse camera-fused outputs for 4-view fusion.
799
+
800
+ This fuses the outputs of camera fusion:
801
+ - CM00_CM01_CHN01 (cameras 0+1, channel 1)
802
+ - CM02_CM03_CHN00 (cameras 2+3, channel 0)
803
+
804
+ Uses channel mode (z-offset + rotation in XZ plane) since these are
805
+ now effectively two different light sheet orientations.
806
+
807
+ Masks are loaded from input_tm_folder (pre-computed from raw data),
808
+ not created from background-subtracted camera-fused volumes.
809
+ """
810
+ output_tm_folder.mkdir(parents=True, exist_ok=True)
811
+
812
+ # output filename for 4-view fused result
813
+ fused_file = output_tm_folder / f"SPM{specimen:02d}_TM{timepoint:06d}_4view_fused.{config.output_format}"
814
+
815
+ if fused_file.exists() and not overwrite:
816
+ print(f" Fusion exists: {fused_file.name}, skipping")
817
+ return
818
+
819
+ # find camera-fused inputs in output folder
820
+ # ch0=0, ch1=1: we need CM02_CM03_CHN00 and CM00_CM01_CHN01
821
+ # standard isoview: cameras 0,1 -> CHN01, cameras 2,3 -> CHN00
822
+
823
+ # find file for channel 0 (cameras 2,3)
824
+ vol0_path = _find_camera_fused_file(output_tm_folder, specimen, timepoint, 2, 3, 0)
825
+ # find file for channel 1 (cameras 0,1)
826
+ vol1_path = _find_camera_fused_file(output_tm_folder, specimen, timepoint, 0, 1, 1)
827
+
828
+ if vol0_path is None:
829
+ print(f" Warning: Camera-fused CHN00 (CM02_CM03) not found, skipping")
830
+ return
831
+ if vol1_path is None:
832
+ print(f" Warning: Camera-fused CHN01 (CM00_CM01) not found, skipping")
833
+ return
834
+
835
+ # load volumes
836
+ print(f" Loading {vol0_path.name}")
837
+ vol0 = read_volume(vol0_path)
838
+ print(f" Loading {vol1_path.name}")
839
+ vol1 = read_volume(vol1_path)
840
+
841
+ # flip second channel to match first (for opposite objective views in isoview)
842
+ # this is required because CHN00 and CHN01 image from opposite sides
843
+ flip_h = getattr(config, 'fusion_channel_flip_h', True)
844
+ flip_v = getattr(config, 'fusion_channel_flip_v', True)
845
+ if flip_h or flip_v:
846
+ flips = []
847
+ if flip_h:
848
+ vol1 = np.flip(vol1, axis=2) # flip X (horizontal)
849
+ flips.append("H")
850
+ if flip_v:
851
+ vol1 = np.flip(vol1, axis=1) # flip Y (vertical)
852
+ flips.append("V")
853
+ print(f" Flipped CHN01: {'+'.join(flips)}")
854
+
855
+ # load pre-computed masks from input folder (computed from raw data before background subtraction)
856
+ # for CHN00 (cameras 2,3): use camera 2's mask
857
+ # for CHN01 (cameras 0,1): use camera 0's mask
858
+ mask0_path = _find_mask_file(input_tm_folder, specimen, timepoint, 2, 0)
859
+ mask1_path = _find_mask_file(input_tm_folder, specimen, timepoint, 0, 1)
860
+
861
+ if mask0_path and mask0_path.exists():
862
+ print(f" Loading mask {mask0_path.name}")
863
+ mask0 = read_volume(mask0_path) > 0
864
+ print(f" mask0: {mask0.sum()}/{mask0.size} pixels ({100*mask0.sum()/mask0.size:.1f}%)")
865
+ else:
866
+ print(f" Warning: Mask for CHN00 not found, creating from volume")
867
+ mask0 = create_fusion_mask(vol0, config.mask_percentile)
868
+
869
+ if mask1_path and mask1_path.exists():
870
+ print(f" Loading mask {mask1_path.name}")
871
+ mask1 = read_volume(mask1_path) > 0
872
+ # apply same flip as volume
873
+ if flip_h:
874
+ mask1 = np.flip(mask1, axis=2)
875
+ if flip_v:
876
+ mask1 = np.flip(mask1, axis=1)
877
+ print(f" mask1: {mask1.sum()}/{mask1.size} pixels ({100*mask1.sum()/mask1.size:.1f}%)")
878
+ else:
879
+ print(f" Warning: Mask for CHN01 not found, creating from volume")
880
+ mask1 = create_fusion_mask(vol1, config.mask_percentile)
881
+
882
+ # subtract background BEFORE estimating transform (matches MATLAB)
883
+ # MATLAB subtracts background from dataSlices before calling transformChannel
884
+ background = estimate_background(vol0, vol1, config.background_percentile)
885
+ vol0_sub = subtract_background(vol0, background)
886
+ vol1_sub = subtract_background(vol1, background)
887
+ print(f" Background subtraction: {background:.1f}")
888
+
889
+ # estimate channel transform (z-offset + rotation in XZ plane)
890
+ # use background-subtracted volumes for better correlation
891
+ if estimate_params:
892
+ print(f" Estimating channel transformation")
893
+ scaling = getattr(config, 'axial_scaling', 1.0)
894
+ slab_size = getattr(config, 'channel_slab_size', 5) # use slab for robustness
895
+ transform = estimate_channel_transform(
896
+ vol0_sub, vol1_sub, # use background-subtracted volumes
897
+ scaling=scaling,
898
+ optimizer=getattr(config, 'fusion_optimizer', 'gradient_descent'),
899
+ slab_size=slab_size,
900
+ verbose=True
901
+ )
902
+ print(f" Transform: z_offset={transform['z_offset']:.2f}, rot={transform['rotation']:.3f}°, corr={transform['correlation']:.3f}")
903
+ else:
904
+ scaling = getattr(config, 'axial_scaling', 1.0)
905
+ transform = estimate_channel_transform(vol0_sub, vol1_sub, scaling=scaling, verbose=True)
906
+
907
+ if not apply_fusion:
908
+ return
909
+
910
+ print(f" Applying 4-view fusion")
911
+
912
+ # apply channel transform (z + rotation in XZ plane)
913
+ scaling = getattr(config, 'axial_scaling', 1.0)
914
+ vol1_transformed = apply_channel_transform(vol1_sub, transform, scaling)
915
+ mask1_transformed = apply_channel_transform(mask1.astype(np.float32), transform, scaling) > 0.5
916
+
917
+ # intensity correction (applied to transformed moving volume)
918
+ correction = estimate_intensity_correction(vol0_sub, vol1_transformed, mask0, mask1_transformed)
919
+ print(f" Intensity: factor={correction['factor']:.3f}")
920
+ vol1_corrected = apply_intensity_correction(vol1_transformed, correction)
921
+
922
+ # compute blending masks - channel fusion blends along Z (same as camera fusion)
923
+ slice_mask0 = create_slice_mask(vol0_sub, mask0)
924
+ slice_mask1 = create_slice_mask(vol1_corrected, mask1_transformed)
925
+
926
+ # debug: check slice mask stats
927
+ sm0_valid = (slice_mask0 > 0).sum()
928
+ sm1_valid = (slice_mask1 > 0).sum()
929
+ print(f" slice_mask0: {sm0_valid}/{slice_mask0.size} valid", end="")
930
+ if sm0_valid > 0:
931
+ print(f", Z-range [{slice_mask0[slice_mask0>0].min()}-{slice_mask0[slice_mask0>0].max()}]")
932
+ else:
933
+ print()
934
+ print(f" slice_mask1: {sm1_valid}/{slice_mask1.size} valid", end="")
935
+ if sm1_valid > 0:
936
+ print(f", Z-range [{slice_mask1[slice_mask1>0].min()}-{slice_mask1[slice_mask1>0].max()}]")
937
+ else:
938
+ print()
939
+
940
+ avg_mask = create_average_mask(slice_mask0, slice_mask1, mode="union")
941
+
942
+ # check if mask is valid before removing anomalies
943
+ valid_before = (avg_mask > 0).sum()
944
+ z_size = vol0.shape[0] # use Z dimension for blending
945
+ blending = config.fusion_blending_range[0]
946
+
947
+ # automatically reduce blending if volume is too small
948
+ max_blending = (z_size - 1) // 2
949
+ effective_blending = min(blending, max_blending)
950
+ if effective_blending < blending:
951
+ print(f" Note: reduced blending range from {blending} to {effective_blending} (z_size={z_size})")
952
+ blending = effective_blending
953
+
954
+ if valid_before > 0:
955
+ avg_mask = remove_mask_anomalies(avg_mask, z_size, blending)
956
+
957
+ # if mask is empty, use geometric center as fallback
958
+ valid_after = (avg_mask > 0).sum()
959
+ if valid_after == 0:
960
+ center_z = z_size // 2
961
+ print(f" Mask empty after processing, using geometric center (Z={center_z})")
962
+ avg_mask = np.full((vol0.shape[1], vol0.shape[2]), center_z, dtype=np.uint16) # (Y, X)
963
+
964
+ # print mask stats
965
+ valid_mask = avg_mask > 0
966
+ print(f" Mask: {valid_mask.sum()}/{avg_mask.size} valid, Z-range [{avg_mask[valid_mask].min()}-{avg_mask[valid_mask].max()}]")
967
+
968
+ # blend along Z axis
969
+ fused = blend_views(
970
+ [vol0_sub, vol1_corrected],
971
+ [avg_mask, avg_mask],
972
+ method=config.fusion_type,
973
+ blending_range=blending,
974
+ front_flag=config.fusion_front_flag
975
+ )
976
+
977
+ # save
978
+ write_volume(fused, fused_file)
979
+ print(f" Saved to {fused_file}")