isoview 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
isoview/transforms.py ADDED
@@ -0,0 +1,1115 @@
1
+ from typing import Optional, Literal
2
+
3
+ import numpy as np
4
+ from scipy import ndimage
5
+ from scipy.ndimage import map_coordinates
6
+ from skimage.registration import phase_cross_correlation
7
+
8
+
9
+ # -----------------------------------------------------------------------------
10
+ # matlab-equivalent transform functions
11
+ # -----------------------------------------------------------------------------
12
+
13
def transform_channel_slice(
    slice2d: np.ndarray,
    z_offset: float,
    rotation_deg: float,
    scaling: float = 1.0
) -> np.ndarray:
    """
    Transform 2D slice (XZ plane) with z-offset and rotation.

    Equivalent to MATLAB transformChannel - applies rotation around center
    then z translation. Used for channel-channel fusion.

    Parameters
    ----------
    slice2d : ndarray
        2D slice (X, Z) from XZ plane (one Y slice)
    z_offset : float
        Z offset in pixels. With zero rotation, content at input z appears at
        output z + z_offset.
    rotation_deg : float
        Rotation angle in degrees
    scaling : float
        Axial/lateral pixel ratio for anisotropic voxels. Z coordinates are
        multiplied by this factor before rotating and divided afterwards, so
        the rotation is performed in lateral-pixel units.

    Returns
    -------
    ndarray
        Transformed slice (X, Z) with the same dtype as ``slice2d``. Output
        pixels whose source coordinate lies outside the input are 0. For
        integer dtypes the interpolated values are rounded and clipped to the
        dtype's range before casting.
    """
    x_size, z_size = slice2d.shape

    # output pixel grid, centred so the rotation pivots about the slice centre
    xi, zi = np.meshgrid(np.arange(x_size), np.arange(z_size), indexing='ij')
    cx = (x_size - 1) / 2.0
    cz = (z_size - 1) / 2.0
    xi_c = xi - cx
    zi_c = zi - cz

    # rotation matrix (matches MATLAB convention)
    # matlab uses: RR = [cos sin; -sin cos] with XZaux = (RR * [X; Z]')'
    # which gives: X_new = cos*x + sin*z, Z_new = -sin*x + cos*z
    theta = np.deg2rad(rotation_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    # scale z into lateral-pixel units, rotate, then unscale
    zi_scaled = zi_c * scaling
    xi_rot = cos_t * xi_c + sin_t * zi_scaled
    zi_rot = -sin_t * xi_c + cos_t * zi_scaled

    # recentre and apply the z translation (source coordinates to sample)
    xi_new = xi_rot + cx
    zi_new = (zi_rot / scaling) + cz - z_offset

    # cubic spline interpolation at the source coordinates
    coords = np.array([xi_new.ravel(), zi_new.ravel()])
    transformed = map_coordinates(
        slice2d.astype(np.float64), coords, order=3, mode='constant', cval=0
    ).reshape(slice2d.shape)

    # zero pixels whose source coordinate falls outside the input
    valid = (xi_new >= 0) & (xi_new < x_size) & (zi_new >= 0) & (zi_new < z_size)
    transformed = transformed * valid

    # Cubic splines ring around sharp edges and can yield values below 0 or
    # above the input maximum. Casting those straight to an integer dtype wraps
    # around (e.g. -3.0 -> 65533 for uint16) and truncates instead of rounding,
    # so round and clip to the representable range first.
    if np.issubdtype(slice2d.dtype, np.integer):
        info = np.iinfo(slice2d.dtype)
        transformed = np.clip(np.rint(transformed), info.min, info.max)

    return transformed.astype(slice2d.dtype)
84
+
85
+
86
def transform_camera_slice(
    slice2d: np.ndarray,
    x_offset: float,
    y_offset: float,
    rotation_deg: float
) -> np.ndarray:
    """
    Transform 2D slice (XY plane) with x/y offset and rotation.

    Equivalent to MATLAB transformCamera - applies rotation around center
    then x/y translation. Used for camera-camera fusion.

    Parameters
    ----------
    slice2d : ndarray
        2D slice (X, Y) from XY plane (one Z slice)
    x_offset : float
        X offset in pixels. With zero rotation, content at input x appears at
        output x + x_offset.
    y_offset : float
        Y offset in pixels (same convention as ``x_offset``)
    rotation_deg : float
        Rotation angle in degrees around Z axis

    Returns
    -------
    ndarray
        Transformed slice (X, Y) with the same dtype as ``slice2d``. Output
        pixels whose source coordinate lies outside the input are 0. For
        integer dtypes the interpolated values are rounded and clipped to the
        dtype's range before casting.
    """
    x_size, y_size = slice2d.shape

    # output pixel grid, centred so the rotation pivots about the slice centre
    xi, yi = np.meshgrid(np.arange(x_size), np.arange(y_size), indexing='ij')
    cx = (x_size - 1) / 2.0
    cy = (y_size - 1) / 2.0
    xi_c = xi - cx
    yi_c = yi - cy

    # rotation matrix (matches MATLAB convention)
    # matlab uses: RR = [cos -sin; sin cos] with XYaux = [X Y] * RR
    # which gives: X_new = cos*x + sin*y, Y_new = -sin*x + cos*y
    theta = np.deg2rad(rotation_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    xi_rot = cos_t * xi_c + sin_t * yi_c
    yi_rot = -sin_t * xi_c + cos_t * yi_c

    # recentre and apply the translation (source coordinates to sample)
    xi_new = xi_rot + cx - x_offset
    yi_new = yi_rot + cy - y_offset

    # cubic spline interpolation at the source coordinates
    coords = np.array([xi_new.ravel(), yi_new.ravel()])
    transformed = map_coordinates(
        slice2d.astype(np.float64), coords, order=3, mode='constant', cval=0
    ).reshape(slice2d.shape)

    # zero pixels whose source coordinate falls outside the input
    valid = (xi_new >= 0) & (xi_new < x_size) & (yi_new >= 0) & (yi_new < y_size)
    transformed = transformed * valid

    # Cubic splines ring around sharp edges and can yield values below 0 or
    # above the input maximum. Casting those straight to an integer dtype wraps
    # around (e.g. -3.0 -> 65533 for uint16) and truncates instead of rounding,
    # so round and clip to the representable range first.
    if np.issubdtype(slice2d.dtype, np.integer):
        info = np.iinfo(slice2d.dtype)
        transformed = np.clip(np.rint(transformed), info.min, info.max)

    return transformed.astype(slice2d.dtype)
153
+
154
+
155
def _correlation_cost_channel(
    params: np.ndarray,
    slice1: np.ndarray,
    slice2: np.ndarray,
    scaling: float
) -> float:
    """
    Objective for channel (XZ) alignment: negative Pearson correlation.

    Equivalent to MATLAB transformChannel objective function. The moving
    slice is warped with ``transform_channel_slice`` and compared to the
    reference only where both are strictly positive.

    Parameters
    ----------
    params : ndarray
        [z_offset, rotation_deg]
    slice1 : ndarray
        Reference slice (X, Z)
    slice2 : ndarray
        Moving slice (X, Z)
    scaling : float
        Axial/lateral pixel ratio

    Returns
    -------
    float
        Negative correlation coefficient (lower is better), or 1.0 when fewer
        than 100 pixels overlap or either side has (near-)zero variance.
    """
    worst_case = 1.0
    z_shift, angle = params
    moved = transform_channel_slice(slice2, z_shift, angle, scaling)

    overlap = (moved > 0) & (slice1 > 0)
    if overlap.sum() < 100:
        return worst_case

    ref_vals = slice1[overlap].astype(np.float64)
    mov_vals = moved[overlap].astype(np.float64)
    ref_dev = ref_vals - ref_vals.mean()
    mov_dev = mov_vals - mov_vals.mean()

    ref_sd = ref_dev.std()
    mov_sd = mov_dev.std()
    if ref_sd < 1e-8 or mov_sd < 1e-8:
        return worst_case

    return -(np.mean(ref_dev * mov_dev) / (ref_sd * mov_sd))
207
+
208
+
209
def _correlation_cost_camera(
    params: np.ndarray,
    slice1: np.ndarray,
    slice2: np.ndarray
) -> float:
    """
    Objective for camera (XY) alignment: negative Pearson correlation.

    Equivalent to MATLAB transformCamera objective function. The moving
    slice is warped with ``transform_camera_slice`` and compared to the
    reference only where both are strictly positive.

    Parameters
    ----------
    params : ndarray
        [x_offset, y_offset, rotation_deg]
    slice1 : ndarray
        Reference slice (X, Y)
    slice2 : ndarray
        Moving slice (X, Y)

    Returns
    -------
    float
        Negative correlation coefficient (lower is better), or 1.0 when fewer
        than 100 pixels overlap or either side has (near-)zero variance.
    """
    worst_case = 1.0
    dx, dy, angle = params
    moved = transform_camera_slice(slice2, dx, dy, angle)

    overlap = (moved > 0) & (slice1 > 0)
    if overlap.sum() < 100:
        return worst_case

    ref_vals = slice1[overlap].astype(np.float64)
    mov_vals = moved[overlap].astype(np.float64)
    ref_dev = ref_vals - ref_vals.mean()
    mov_dev = mov_vals - mov_vals.mean()

    ref_sd = ref_dev.std()
    mov_sd = mov_dev.std()
    if ref_sd < 1e-8 or mov_sd < 1e-8:
        return worst_case

    return -(np.mean(ref_dev * mov_dev) / (ref_sd * mov_sd))
257
+
258
+
259
def gradient_descent_optimize(
    cost_fn,
    x0: np.ndarray,
    step_size: np.ndarray,
    tol_x: float = 1e-3,
    tol_fun: float = 1e-6,
    max_evals: int = 400,
    beta: float = 0.5
) -> tuple[np.ndarray, float, int]:
    """
    Gradient descent optimizer with adaptive step size.

    Equivalent to MATLAB fminuncFA by Fernando Amat.
    Uses a central-difference numerical gradient and a backtracking line
    search. Every probe evaluated while estimating the gradient is tracked,
    so the best point seen in an iteration is never discarded.

    Parameters
    ----------
    cost_fn : callable
        Cost function f(x) -> float to minimize
    x0 : array_like
        Initial parameter vector (copied; the caller's array is not modified)
    step_size : array_like
        Per-parameter step for the central-difference gradient; also sets the
        scale of the first line-search step.
    tol_x : float
        Convergence tolerance for the length of the parameter update
    tol_fun : float
        Relative convergence tolerance for the function value change
    max_evals : int
        Evaluation budget. It is checked once per iteration, so the final
        count can exceed it by up to one iteration's evaluations.
    beta : float
        Line search backtracking factor

    Returns
    -------
    x_sol : ndarray
        Optimal parameters
    f_val : float
        Optimal function value
    exit_flag : int
        1 = converged (relative function change below tol_fun),
        2 = converged (update below tol_x, or the numerical gradient vanished
            and no probe was better than the current point),
        3 = evaluation budget exhausted
    """
    x = np.array(x0, dtype=np.float64)  # copy; also accepts lists/tuples
    n = len(x)
    step_size = np.asarray(step_size, dtype=np.float64)

    alpha = 1.0  # adaptive multiplier for line search
    num_eval = 0

    f0 = cost_fn(x)
    num_eval += 1

    # best point seen, including gradient probes
    f_best = f0
    x_best = x.copy()

    while num_eval < max_evals:
        # central-difference gradient
        g = np.zeros(n)
        for k in range(n):
            h = np.zeros(n)
            h[k] = step_size[k]

            fp = cost_fn(x + h)
            num_eval += 1
            if fp < f_best:
                f_best = fp
                x_best = x + h

            fm = cost_fn(x - h)
            num_eval += 1
            if fm < f_best:
                f_best = fm
                x_best = x - h

            g[k] = (fp - fm) / (2.0 * step_size[k])

        if np.linalg.norm(g) < 1e-10:
            # A vanishing central difference does not imply a minimum: at a
            # local maximum or saddle, both probes can be lower than f0 and
            # cancel in the difference. Jump to the better probe and continue;
            # only stop when nothing better was found.
            if f_best < f0:
                x = x_best.copy()
                f0 = f_best
                continue
            return x, f0, 2

        # initial line-search step; |g| in the denominator keeps it finite for
        # negative gradient components (g + eps can hit exactly zero)
        mu = alpha * 5.0 * np.min(np.abs(step_size) / (np.abs(g) + 1e-10))
        f1 = cost_fn(x - mu * g)
        num_eval += 1

        # backtrack until the step decreases the cost or becomes negligible
        aux_alpha = 1.0
        while f1 > f0 and np.linalg.norm(mu * g) > tol_x:
            mu *= beta
            f1 = cost_fn(x - mu * g)
            num_eval += 1
            aux_alpha *= beta

        # shrink the next initial step in proportion to how much we backtracked
        alpha = 0.3 * alpha + 0.7 * alpha * aux_alpha

        # check convergence (only if no gradient probe beat the line search)
        if np.linalg.norm(mu * g) < tol_x and f_best + tol_fun >= min(f1, f0):
            if f1 < f0:
                return x - mu * g, f1, 2
            return x, f0, 2

        if abs(f1 - f0) / (abs(f0) + 1e-10) < tol_fun and f_best + tol_fun >= min(f1, f0):
            if f1 < f0:
                return x - mu * g, f1, 1
            return x, f0, 1

        # move to the line-search point, or to a better gradient probe
        if f1 <= f_best:
            x = x - mu * g
            f0 = f1
        else:
            x = x_best.copy()
            f0 = f_best

        f_best = f0
        x_best = x.copy()

    return x, f0, 3
379
+
380
+
381
def estimate_channel_transform(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    scaling: float = 1.0,
    y_slice: Optional[int] = None,
    optimizer: Literal["gradient_descent", "nelder_mead", "bfgs"] = "gradient_descent",
    slab_size: int = 0,
    verbose: bool = False
) -> dict:
    """
    Estimate channel-channel alignment transform (z-offset + rotation).

    Equivalent to MATLAB channel alignment in multiFuse.
    Works on XZ slices (one Y plane) to find optimal z-offset and rotation.

    Parameters
    ----------
    ref_volume : ndarray
        Reference volume (Z, Y, X)
    moving_volume : ndarray
        Moving volume (Z, Y, X); Z and X sizes must match ``ref_volume``
    scaling : float
        Axial/lateral pixel ratio (z_spacing / xy_spacing)
    y_slice : int or None
        Y slice to use for registration. None uses center slice. Negative
        values index from the end, as in numpy.
    optimizer : str
        'gradient_descent' (MATLAB fminuncFA), 'nelder_mead', or 'bfgs'
    slab_size : int
        Number of Y slices to average for more robust estimation. 0 = single
        slice. The slab spans ``slab_size // 2`` slices on each side of
        ``y_slice``, clipped to the volume.
    verbose : bool
        Print debug information.

    Returns
    -------
    dict
        Contains 'z_offset', 'rotation', 'correlation', 'scaling', 'method'

    Raises
    ------
    ValueError
        If either volume is not 3D, their Z/X sizes differ, or ``optimizer``
        is not a supported name.
    IndexError
        If ``y_slice`` is outside ``[-Y, Y)``.
    """
    if optimizer not in ("gradient_descent", "nelder_mead", "bfgs"):
        raise ValueError(f"Unknown optimizer: {optimizer}")
    if ref_volume.ndim != 3 or moving_volume.ndim != 3:
        raise ValueError(
            f"expected 3D (Z, Y, X) volumes, got shapes {ref_volume.shape} "
            f"and {moving_volume.shape}"
        )
    if (ref_volume.shape[0], ref_volume.shape[2]) != (moving_volume.shape[0], moving_volume.shape[2]):
        raise ValueError(
            f"Z/X sizes differ between volumes: {ref_volume.shape} vs {moving_volume.shape}"
        )

    y_size = ref_volume.shape[1]
    if y_slice is None:
        y_slice = y_size // 2
    requested = y_slice
    if y_slice < 0:
        y_slice += y_size
    # validated explicitly: in slab mode an out-of-range index would otherwise
    # produce an empty slab whose mean is NaN instead of raising
    if not 0 <= y_slice < y_size:
        raise IndexError(f"y_slice {requested} out of range for {y_size} Y slices")

    # extract XZ slice(s) at specified Y, as (X, Z) float64 images
    if slab_size > 0:
        # average over a slab of Y slices for robustness
        half = slab_size // 2
        y_start = max(0, y_slice - half)
        y_end = min(y_size, y_slice + half + 1)
        slice1 = ref_volume[:, y_start:y_end, :].mean(axis=1).T.astype(np.float64)
        slice2 = moving_volume[:, y_start:y_end, :].mean(axis=1).T.astype(np.float64)
        if verbose:
            print(f" Using Y slab [{y_start}:{y_end}] ({y_end-y_start} slices)")
    else:
        slice1 = ref_volume[:, y_slice, :].T.astype(np.float64)
        slice2 = moving_volume[:, y_slice, :].T.astype(np.float64)

    if verbose:
        print(f" slice1: shape={slice1.shape}, range=[{slice1.min():.1f}, {slice1.max():.1f}], mean={slice1.mean():.1f}")
        print(f" slice2: shape={slice2.shape}, range=[{slice2.min():.1f}, {slice2.max():.1f}], mean={slice2.mean():.1f}")
        # correlation at the identity transform, for comparison with the result
        valid = (slice1 > 0) & (slice2 > 0)
        if valid.sum() > 100:
            s1 = slice1[valid]
            s2 = slice2[valid]
            s1_n = (s1 - s1.mean()) / (s1.std() + 1e-8)
            s2_n = (s2 - s2.mean()) / (s2.std() + 1e-8)
            init_corr = np.mean(s1_n * s2_n)
            print(f" Initial correlation at identity: {init_corr:.3f} ({valid.sum()} valid pixels)")

    x0 = np.array([0.0, 0.0])  # [z_offset, rotation_deg]
    step_size = np.array([1.0, 0.1])

    def cost_fn(p):
        return _correlation_cost_channel(p, slice1, slice2, scaling)

    if optimizer == "gradient_descent":
        x_sol, f_val, _ = gradient_descent_optimize(cost_fn, x0, step_size)
    else:
        from scipy.optimize import minimize
        if optimizer == "nelder_mead":
            result = minimize(cost_fn, x0, method='Nelder-Mead',
                              options={'xatol': 1e-3, 'fatol': 1e-6, 'maxfev': 400})
        else:  # "bfgs" (validated above)
            result = minimize(cost_fn, x0, method='BFGS',
                              options={'gtol': 1e-6, 'maxiter': 200})
        x_sol = result.x
        f_val = result.fun

    return {
        'z_offset': float(x_sol[0]),
        'rotation': float(x_sol[1]),
        'correlation': float(-f_val),
        'scaling': scaling,
        'method': f'channel_{optimizer}'
    }
481
+
482
+
483
def estimate_camera_transform(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    search_x: tuple[int, int, int] = (-50, 50, 10),
    search_y: tuple[int, int, int] = (-50, 50, 10),
    z_slice: Optional[int] = None,
    optimizer: Literal["gradient_descent", "nelder_mead", "bfgs"] = "gradient_descent"
) -> dict:
    """
    Estimate camera-camera alignment transform (x/y offset + rotation).

    Equivalent to MATLAB camera alignment in multiFuse. A coarse integer
    grid search over (x, y) shifts is followed by a fine optimization of
    (x_offset, y_offset, rotation) seeded with the coarse result.

    Parameters
    ----------
    ref_volume : ndarray
        Reference volume (Z, Y, X)
    moving_volume : ndarray
        Moving volume (Z, Y, X)
    search_x : tuple
        (start, stop, step) for the coarse x search; ``stop`` is inclusive
        when it lies on the step grid.
    search_y : tuple
        (start, stop, step) for the coarse y search (same convention)
    z_slice : int or None
        Z slice to use. None uses a max intensity projection over Z.
    optimizer : str
        'gradient_descent' (MATLAB fminuncFA), 'nelder_mead', or 'bfgs'

    Returns
    -------
    dict
        'x_offset', 'y_offset', 'z_offset' (always 0.0), 'rotation',
        'correlation', 'coarse_x', 'coarse_y', 'method'
    """
    # unpacking the shape doubles as a check that the reference is 3D
    _, _, _ = ref_volume.shape

    # representative (X, Y) planes
    if z_slice is None:
        ref_xy = ref_volume.max(axis=0).astype(np.float64).T
        mov_xy = moving_volume.max(axis=0).astype(np.float64).T
    else:
        ref_xy = ref_volume[z_slice, :, :].T.astype(np.float64)
        mov_xy = moving_volume[z_slice, :, :].T.astype(np.float64)

    def _shift_axis(arr, offset, axis):
        """Integer shift of ``arr`` along ``axis``, zero-filling vacated pixels."""
        out = np.zeros_like(arr)
        dst = [slice(None)] * arr.ndim
        src = [slice(None)] * arr.ndim
        if offset >= 0:
            dst[axis] = slice(offset, None)
            src[axis] = slice(None, -offset or None)
        else:
            dst[axis] = slice(None, offset)
            src[axis] = slice(-offset, None)
        out[tuple(dst)] = arr[tuple(src)]
        return out

    # coarse grid search without rotation, scored by the normalized
    # intensity product sum(ref * moved) / (sum(ref) * sum(moved))
    grid_x = np.arange(search_x[0], search_x[1] + 1, search_x[2])
    grid_y = np.arange(search_y[0], search_y[1] + 1, search_y[2])
    ref_total = ref_xy.sum()

    best_score = -np.inf
    coarse_x = 0.0
    coarse_y = 0.0
    for dx in grid_x:
        moved_x = _shift_axis(mov_xy, dx, 0)
        for dy in grid_y:
            candidate = _shift_axis(moved_x, dy, 1)
            cand_total = candidate.sum()
            if cand_total < 1e-10:
                continue  # shifted entirely out of frame
            score = (ref_xy * candidate).sum() / (ref_total * cand_total + 1e-10)
            if score > best_score:
                best_score = score
                coarse_x = float(dx)
                coarse_y = float(dy)

    # fine optimization with rotation, seeded by the coarse result
    start = np.array([coarse_x, coarse_y, 0.0])  # [x_offset, y_offset, rotation_deg]
    steps = np.array([1.0, 1.0, 0.1])

    def cost_fn(p):
        return _correlation_cost_camera(p, ref_xy, mov_xy)

    if optimizer == "gradient_descent":
        solution, fval, _ = gradient_descent_optimize(cost_fn, start, steps)
    elif optimizer in ("nelder_mead", "bfgs"):
        from scipy.optimize import minimize
        if optimizer == "nelder_mead":
            res = minimize(cost_fn, start, method='Nelder-Mead',
                           options={'xatol': 1e-3, 'fatol': 1e-6, 'maxfev': 400})
        else:
            res = minimize(cost_fn, start, method='BFGS',
                           options={'gtol': 1e-6, 'maxiter': 200})
        solution = res.x
        fval = res.fun
    else:
        raise ValueError(f"Unknown optimizer: {optimizer}")

    return {
        'x_offset': float(solution[0]),
        'y_offset': float(solution[1]),
        'z_offset': 0.0,
        'rotation': float(solution[2]),
        'correlation': float(-fval),
        'coarse_x': coarse_x,
        'coarse_y': coarse_y,
        'method': f'camera_{optimizer}'
    }
602
+
603
+
604
def apply_channel_transform(
    volume: np.ndarray,
    transform: dict,
    scaling: float = 1.0
) -> np.ndarray:
    """
    Apply a channel transform (z-offset + rotation) to a whole volume.

    Every XZ plane (one per Y index) is warped with the same z-offset and
    rotation via ``transform_channel_slice``.

    Parameters
    ----------
    volume : ndarray
        Input volume (Z, Y, X)
    transform : dict
        Reads 'z_offset' and 'rotation' (default 0.0 each). If it also has
        a 'scaling' entry, that value is used instead of the ``scaling``
        argument.
    scaling : float
        Axial/lateral pixel ratio, used when ``transform`` has no 'scaling'

    Returns
    -------
    ndarray
        Transformed volume (Z, Y, X), same dtype as ``volume``
    """
    dz = transform.get('z_offset', 0.0)
    angle = transform.get('rotation', 0.0)
    aniso = transform.get('scaling', scaling)

    _, n_y, _ = volume.shape
    out = np.zeros_like(volume)
    for iy in range(n_y):
        # (Z, X) plane -> (X, Z) for the slice transform, then back
        plane_xz = volume[:, iy, :].T
        out[:, iy, :] = transform_channel_slice(plane_xz, dz, angle, aniso).T
    return out
642
+
643
+
644
def apply_camera_transform(
    volume: np.ndarray,
    transform: dict
) -> np.ndarray:
    """
    Apply a camera transform (x/y offset + rotation) to a whole volume.

    Every XY plane (one per Z index) is warped with the same offsets and
    rotation via ``transform_camera_slice``.

    Parameters
    ----------
    volume : ndarray
        Input volume (Z, Y, X)
    transform : dict
        Reads 'x_offset', 'y_offset' and 'rotation' (default 0.0 each)

    Returns
    -------
    ndarray
        Transformed volume (Z, Y, X), same dtype as ``volume``
    """
    dx = transform.get('x_offset', 0.0)
    dy = transform.get('y_offset', 0.0)
    angle = transform.get('rotation', 0.0)

    n_z, _, _ = volume.shape
    out = np.zeros_like(volume)
    for iz in range(n_z):
        # (Y, X) plane -> (X, Y) for the slice transform, then back
        plane_xy = volume[iz, :, :].T
        out[iz, :, :] = transform_camera_slice(plane_xy, dx, dy, angle).T
    return out
679
+
680
+
681
+ # -----------------------------------------------------------------------------
682
+ # original utility functions
683
+ # -----------------------------------------------------------------------------
684
+
685
def rotate_volume(volume: np.ndarray, rotation: int) -> np.ndarray:
    """
    Rotate a 3D volume by a quarter turn in the (Y, X) plane.

    Parameters
    ----------
    volume : ndarray
        3D image stack (z, y, x)
    rotation : int
        0 (none, the input is returned unchanged), 1 (90 deg cw),
        -1 (90 deg ccw)

    Returns
    -------
    ndarray
        Rotated volume (a view where numpy allows)

    Raises
    ------
    ValueError
        For any other ``rotation`` value.
    """
    if rotation == 0:
        return volume
    if rotation == 1 or rotation == -1:
        # np.rot90 turns counter-clockwise for positive k
        quarter_turns = 1 if rotation == -1 else -1
        return np.rot90(volume, k=quarter_turns, axes=(1, 2))
    raise ValueError(f"invalid rotation: {rotation}")
709
+
710
+
711
def flip_volume(volume: np.ndarray, horizontal: bool, vertical: bool) -> np.ndarray:
    """
    Mirror a volume left-right and/or top-bottom.

    Parameters
    ----------
    volume : ndarray
        3D image stack (z, y, x)
    horizontal : bool
        Reverse the x axis (left-right)
    vertical : bool
        Reverse the y axis (top-bottom)

    Returns
    -------
    ndarray
        Flipped view of ``volume``; the input object itself when neither
        flag is set.
    """
    if not (horizontal or vertical):
        return volume
    row_step = -1 if vertical else 1
    col_step = -1 if horizontal else 1
    return volume[:, ::row_step, ::col_step]
735
+
736
+
737
def crop_volume(
    volume: np.ndarray,
    top: int = 0,
    left: int = 0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    front: int = 0,
    depth: Optional[int] = None,
) -> np.ndarray:
    """
    Crop a 3D volume to a box-shaped region of interest.

    Parameters
    ----------
    volume : ndarray
        3D image stack (z, y, x)
    top : int, default=0
        First y index kept
    left : int, default=0
        First x index kept
    height : int or None, default=None
        Number of y rows kept; None keeps everything from ``top`` onwards
    width : int or None, default=None
        Number of x columns kept; None keeps everything from ``left`` onwards
    front : int, default=0
        First z index kept
    depth : int or None, default=None
        Number of z planes kept; None keeps everything from ``front`` onwards

    Returns
    -------
    ndarray
        Cropped view of ``volume`` (plain slicing, so out-of-range extents
        are silently truncated)
    """
    n_z, n_y, n_x = volume.shape

    z_stop = n_z if depth is None else front + depth
    y_stop = n_y if height is None else top + height
    x_stop = n_x if width is None else left + width

    return volume[front:z_stop, top:y_stop, left:x_stop]
785
+
786
+
787
def apply_mask(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Zero out everything outside a binary mask.

    Parameters
    ----------
    volume : ndarray
        3D image stack
    mask : ndarray
        Binary mask (0 or 1), broadcastable against ``volume``

    Returns
    -------
    ndarray
        Element-wise product ``volume * mask``
    """
    masked = volume * mask
    return masked
804
+
805
+
806
def estimate_registration(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    ref_mask: np.ndarray,
    moving_mask: np.ndarray,
    search_x: tuple[int, int, int] = (-50, 50, 10),
    search_y: tuple[int, int, int] = (-50, 50, 10),
    method: Literal["phase_cross_correlation", "correlation"] = "phase_cross_correlation",
    slab_size: int = 0
) -> dict:
    """
    Estimate registration transformation between two views.

    Parameters
    ----------
    ref_volume : ndarray
        Reference view (Z, Y, X)
    moving_volume : ndarray
        View to transform (Z, Y, X)
    ref_mask : ndarray
        Binary mask for ref (Z, Y, X)
    moving_mask : ndarray
        Binary mask for moving (Z, Y, X)
    search_x : tuple of int, default=(-50, 50, 10)
        (start, stop, step) for x search range; used only by 'correlation'
    search_y : tuple of int, default=(-50, 50, 10)
        (start, stop, step) for y search range; used only by 'correlation'
    method : str, default='phase_cross_correlation'
        'phase_cross_correlation' (subpixel, estimates z/y/x shift) or
        'correlation' (coarse integer x/y grid search; z_offset is always 0)
    slab_size : int, default=0
        If > 0, split the volume into Z slabs of this thickness, register each
        slab independently and combine the results with a correlation-weighted
        average. 0 or negative registers the whole volume in one pass.

    Returns
    -------
    dict
        Contains 'x_offset', 'y_offset', 'z_offset', 'rotation', 'correlation', 'method'

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.

    Notes
    -----
    Phase cross correlation uses FFT for fast subpixel registration.
    The correlation method searches a discrete grid of integer shifts.
    Neither method estimates rotation; 'rotation' is reported as 0.
    Masks are multiplied into the volumes to focus registration on valid regions.
    """
    # validate once, before any (possibly slab-wise) work is done
    if method not in ("phase_cross_correlation", "correlation"):
        raise ValueError(f"Unknown method: {method}")

    if slab_size > 0:
        return _estimate_registration_slab(
            ref_volume, moving_volume, ref_mask, moving_mask,
            search_x, search_y, method, slab_size
        )

    if method == "phase_cross_correlation":
        return _estimate_registration_phase(ref_volume, moving_volume, ref_mask, moving_mask)
    return _estimate_registration_correlation(
        ref_volume, moving_volume, ref_mask, moving_mask, search_x, search_y
    )
865
+
866
+
867
def _estimate_registration_phase(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    ref_mask: np.ndarray,
    moving_mask: np.ndarray
) -> dict:
    """
    Estimate a z/y/x shift with FFT-based subpixel phase cross-correlation.

    Masks are multiplied into the volumes before registration. The returned
    'correlation' is a confidence score ``1 / (1 + error)`` derived from the
    registration error reported by ``phase_cross_correlation``; it is 0.0
    when that error is not finite. Rotation is not estimated (always 0.0).
    """
    ref_masked = ref_volume * ref_mask
    moving_masked = moving_volume * moving_mask

    shift, error, _ = phase_cross_correlation(
        ref_masked, moving_masked, upsample_factor=10
    )

    z_offset, y_offset, x_offset = shift

    # NOTE(review): `error` can come back NaN when an input carries no signal
    # (e.g. an all-zero masked volume), presumably from a 0/0 normalization in
    # skimage. `NaN > 0` is False, so the previous expression reported such a
    # failure as a perfect match (1.0); report zero confidence instead.
    if np.isfinite(error):
        correlation = 1.0 / (1.0 + error)
    else:
        correlation = 0.0

    return {
        'x_offset': float(x_offset),
        'y_offset': float(y_offset),
        'z_offset': float(z_offset),
        'rotation': 0.0,
        'correlation': float(correlation),
        'method': 'phase_cross_correlation'
    }
894
+
895
+
896
def _estimate_registration_correlation(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    ref_mask: np.ndarray,
    moving_mask: np.ndarray,
    search_x: tuple[int, int, int],
    search_y: tuple[int, int, int]
) -> dict:
    """
    Estimate an x/y shift by exhaustive search over an integer grid.

    For every (x, y) offset on the grid, the masked moving volume is shifted
    (linear interpolation, zero fill) and compared to the masked reference
    with a Pearson correlation over voxels where both are > 0. The best
    offset wins; ties keep the first one found (x outer loop, y inner loop).

    Returns
    -------
    dict
        'x_offset', 'y_offset', 'z_offset' (always 0.0), 'rotation' (always
        0.0), 'correlation' (-inf if no offset produced any overlap) and
        'method'.
    """
    ref_masked = ref_volume * ref_mask
    moving_masked = moving_volume * moving_mask

    # loop-invariant: reference support and its float64 values
    ref_valid = ref_masked > 0
    ref_f64 = ref_masked.astype(np.float64)

    x_start, x_stop, x_step = search_x
    y_start, y_stop, y_step = search_y
    x_offsets = np.arange(x_start, x_stop + 1, x_step)
    y_offsets = np.arange(y_start, y_stop + 1, y_step)

    best_corr = -np.inf
    best_x = 0.0
    best_y = 0.0

    for x_off in x_offsets:
        for y_off in y_offsets:
            shifted = ndimage.shift(
                moving_masked, shift=(0, y_off, x_off), order=1, mode='constant', cval=0
            )

            overlap_mask = ref_valid & (shifted > 0)
            if not overlap_mask.any():
                continue

            ref_overlap = ref_f64[overlap_mask]
            mov_overlap = shifted[overlap_mask].astype(np.float64)

            ref_norm = (ref_overlap - ref_overlap.mean()) / (ref_overlap.std() + 1e-8)
            mov_norm = (mov_overlap - mov_overlap.mean()) / (mov_overlap.std() + 1e-8)
            corr = np.mean(ref_norm * mov_norm)

            if corr > best_corr:
                best_corr = corr
                best_x = x_off
                best_y = y_off

    return {
        'x_offset': float(best_x),
        'y_offset': float(best_y),
        'z_offset': 0.0,
        'rotation': 0.0,
        'correlation': float(best_corr),
        'method': 'correlation'
    }
947
+
948
+
949
def apply_registration(
    volume: np.ndarray,
    transform: dict,
    order: int = 3
) -> np.ndarray:
    """
    Warp a volume with a registration transform.

    Parameters
    ----------
    volume : ndarray
        Input volume (Z, Y, X)
    transform : dict
        Reads 'x_offset', 'y_offset', 'z_offset' and 'rotation' (degrees),
        each defaulting to 0.0
    order : int, default=3
        Spline interpolation order: 0=nearest, 1=linear, 3=cubic

    Returns
    -------
    ndarray
        Transformed volume (Z, Y, X); always a new array, never ``volume``

    Notes
    -----
    Rotation (skipped if |rotation| <= 1e-6) is applied first to each Z plane
    in the Y/X plane with zero fill. Translation (skipped if every offset is
    <= 1e-6 in magnitude) follows with ``mode='nearest'``, so edge voxels
    are replicated instead of zeroed.
    """
    dx = transform.get('x_offset', 0.0)
    dy = transform.get('y_offset', 0.0)
    dz = transform.get('z_offset', 0.0)
    angle = transform.get('rotation', 0.0)

    out = volume.copy()

    if abs(angle) > 1e-6:
        rotated = np.zeros_like(out)
        for iz in range(out.shape[0]):
            rotated[iz, :, :] = ndimage.rotate(
                out[iz, :, :], angle, reshape=False, order=order, mode='constant', cval=0
            )
        out = rotated

    if any(abs(component) > 1e-6 for component in (dx, dy, dz)):
        out = ndimage.shift(out, shift=(dz, dy, dx), order=order, mode='nearest')

    return out
1000
+
1001
+
1002
def apply_registration_to_mask(mask: np.ndarray, transform: dict) -> np.ndarray:
    """
    Warp a binary mask with a registration transform.

    Parameters
    ----------
    mask : ndarray
        Binary mask (Z, Y, X). NOTE(review): 2D masks only pass through an
        identity transform; any rotation or offset goes through the 3D
        per-Z-plane / 3-component shift logic of ``apply_registration``.
    transform : dict
        Transformation parameters, as accepted by ``apply_registration``

    Returns
    -------
    ndarray
        Transformed mask with the same shape and dtype as the input

    Notes
    -----
    The mask is warped as float32 with linear interpolation (order=1) and
    re-binarized with a strict > 0.5 threshold.
    """
    as_float = mask.astype(np.float32)
    warped = apply_registration(as_float, transform, order=1)
    binary = warped > 0.5
    return binary.astype(mask.dtype)
1024
+
1025
+
1026
def _estimate_registration_slab(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    ref_mask: np.ndarray,
    moving_mask: np.ndarray,
    search_x: tuple[int, int, int],
    search_y: tuple[int, int, int],
    method: str,
    slab_size: int
) -> dict:
    """
    Estimate registration using slab-based processing.

    Parameters
    ----------
    ref_volume : ndarray
        Reference view (Z, Y, X)
    moving_volume : ndarray
        View to transform (Z, Y, X)
    ref_mask : ndarray
        Binary mask for ref (Z, Y, X)
    moving_mask : ndarray
        Binary mask for moving (Z, Y, X)
    search_x : tuple of int
        (start, stop, step) for x search range
    search_y : tuple of int
        (start, stop, step) for y search range
    method : str
        'phase_cross_correlation' or 'correlation'
    slab_size : int
        Slab thickness in Z (must be > 0; the last slab may be thinner)

    Returns
    -------
    dict
        Weighted-average 'x_offset', 'y_offset', 'z_offset', 'rotation';
        'correlation' is the plain mean of the per-slab correlations;
        'method' is ``f"{method}_slab"``.

    Raises
    ------
    ValueError
        If ``method`` is not recognized.

    Notes
    -----
    Each Z slab is registered independently and the per-slab parameters are
    averaged with weights proportional to their correlation. Only finite,
    positive correlations carry weight. If no slab has one, all slabs are
    weighted equally.
    """
    Z = ref_volume.shape[0]

    slab_results = []
    for z_start in range(0, Z, slab_size):
        z_end = min(z_start + slab_size, Z)

        ref_slab = ref_volume[z_start:z_end, :, :]
        moving_slab = moving_volume[z_start:z_end, :, :]
        ref_mask_slab = ref_mask[z_start:z_end, :, :]
        moving_mask_slab = moving_mask[z_start:z_end, :, :]

        if method == "phase_cross_correlation":
            result = _estimate_registration_phase(
                ref_slab, moving_slab, ref_mask_slab, moving_mask_slab
            )
        elif method == "correlation":
            result = _estimate_registration_correlation(
                ref_slab, moving_slab, ref_mask_slab, moving_mask_slab,
                search_x, search_y
            )
        else:
            raise ValueError(f"Unknown method: {method}")

        slab_results.append(result)

    correlations = np.array([r['correlation'] for r in slab_results], dtype=np.float64)

    # A negative weight would pull the average *away* from that slab's
    # estimate (and can extrapolate beyond every slab), while -inf/NaN would
    # poison the normalization, so only finite positive correlations count.
    usable = np.isfinite(correlations) & (correlations > 0)
    weights = np.where(usable, correlations, 0.0)
    total = weights.sum()
    if total > 0:
        weights = weights / total
    else:
        weights = np.full(len(slab_results), 1.0 / max(len(slab_results), 1))

    def _weighted(key):
        return float(sum(w * r[key] for w, r in zip(weights, slab_results)))

    return {
        'x_offset': _weighted('x_offset'),
        'y_offset': _weighted('y_offset'),
        'z_offset': _weighted('z_offset'),
        'rotation': _weighted('rotation'),
        'correlation': float(correlations.mean()),
        'method': f"{method}_slab"
    }