isoview 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
isoview/viz.py ADDED
@@ -0,0 +1,723 @@
1
+ """Visualization utilities for isoview processing pipeline."""
2
+
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ from matplotlib import pyplot as plt
7
+ from matplotlib.colors import LinearSegmentedColormap
8
+
9
+
10
# Shared dark-theme settings read by every plotting helper in this module.
STYLE = dict(
    facecolor="black",    # figure and axes background
    text_color="white",   # titles, axis labels, ticks and spines
    grid_color="gray",    # grid line color
    grid_alpha=0.3,       # grid line transparency
    dpi=150,              # resolution used when saving figures
    font_size=10,         # tick labels and text panels
    title_size=12,        # figure-level / main titles
    label_size=10,        # axis labels and subplot titles
)
21
+
22
+
23
def _apply_dark_style(ax, fig=None):
    """Paint *ax* (and, if given, *fig*) with the module's dark theme.

    The axes background, tick colors and tick label size, and every spine
    color are taken from ``STYLE``. When *fig* is supplied, its background
    patch is painted with the same face color.
    """
    background = STYLE["facecolor"]
    foreground = STYLE["text_color"]

    ax.set_facecolor(background)
    ax.tick_params(colors=foreground, labelsize=STYLE["font_size"])
    for edge in ax.spines.values():
        edge.set_color(foreground)

    if fig is None:
        return
    fig.patch.set_facecolor(background)
31
+
32
+
33
def _save_or_show(fig, save_path: Optional[Union[str, Path]] = None):
    """Save *fig* to *save_path*, or display it when no path is given.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure to save or show.
    save_path : str or Path, optional
        Destination file. A falsy value (``None`` or ``""``) displays the
        figure with ``plt.show()`` instead of saving it.

    Notes
    -----
    The figure is written with ``fig.savefig`` rather than ``plt.savefig``,
    so the intended figure is saved even if another figure has become
    pyplot's "current" figure in the meantime. After a save attempt the
    figure is always closed, even when saving raises, so batch runs do not
    accumulate open figures.
    """
    if not save_path:
        plt.show()
        return

    try:
        fig.savefig(
            save_path,
            dpi=STYLE["dpi"],
            facecolor=fig.get_facecolor(),
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)
45
+
46
+
47
def plot_projections(
    volume: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Volume Projections",
    cmap: str = "gray"
) -> None:
    """
    Render the three axis-aligned maximum intensity projections of a volume.

    Parameters
    ----------
    volume : ndarray
        3D volume ordered (Z, Y, X).
    save_path : str or Path, optional
        Where to write the figure; when omitted the figure is shown instead.
    title : str, default='Volume Projections'
        Figure-level title.
    cmap : str, default='gray'
        Colormap applied to all three panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), facecolor=STYLE["facecolor"])

    # Collapsing axis 0 / 1 / 2 of (Z, Y, X) leaves (Y, X) / (Z, X) / (Z, Y).
    views = [
        ("XY (max Z)", np.max(volume, axis=0)),
        ("XZ (max Y)", np.max(volume, axis=1)),
        ("YZ (max X)", np.max(volume, axis=2)),
    ]

    for ax, (label, mip) in zip(axes, views):
        _apply_dark_style(ax)
        ax.imshow(mip, cmap=cmap, aspect="auto")
        ax.set_title(label, color=STYLE["text_color"],
                     fontsize=STYLE["label_size"], fontweight="bold")
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle(title, color=STYLE["text_color"],
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
86
+
87
+
88
def plot_histogram(
    volume: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Intensity Histogram",
    bins: int = 256,
    log_scale: bool = True,
    percentile_lines: tuple = (1, 5, 95, 99)
) -> None:
    """
    Plot intensity histogram with percentile markers.

    Parameters
    ----------
    volume : ndarray
        Volume of any shape; all voxels are pooled into one histogram.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    title : str, default='Intensity Histogram'
        Figure title
    bins : int, default=256
        Number of histogram bins
    log_scale : bool, default=True
        Use log scale for y-axis
    percentile_lines : tuple, default=(1, 5, 95, 99)
        Percentile values to mark with dashed vertical lines. Any number of
        percentiles may be given; marker colors cycle through a 4-color
        palette.
    """
    fig, ax = plt.subplots(figsize=(8, 4), facecolor=STYLE["facecolor"])
    _apply_dark_style(ax, fig)

    flat = volume.ravel()
    ax.hist(flat, bins=bins, color="#3498db", alpha=0.8, edgecolor="none")

    if log_scale:
        ax.set_yscale("log")

    # Percentile markers. The palette is indexed modulo its length so every
    # requested percentile gets a line; zipping against the fixed 4-color
    # list would silently drop everything after the fourth percentile.
    palette = ("#e74c3c", "#f39c12", "#f39c12", "#e74c3c")
    percentiles = list(percentile_lines)
    if percentiles:
        # One np.percentile call over the (possibly large) data instead of
        # one call per marker.
        values = np.percentile(flat, percentiles)
        for i, (p, val) in enumerate(zip(percentiles, values)):
            c = palette[i % len(palette)]
            ax.axvline(val, color=c, linestyle="--", alpha=0.8, linewidth=1.5)
            ax.text(val, ax.get_ylim()[1] * 0.9, f"{p}%",
                    color=c, fontsize=8, ha="center", va="top")

    ax.set_xlabel("Intensity", color=STYLE["text_color"], fontsize=STYLE["label_size"], fontweight="bold")
    ax.set_ylabel("Count", color=STYLE["text_color"], fontsize=STYLE["label_size"], fontweight="bold")
    ax.set_title(title, color=STYLE["text_color"], fontsize=STYLE["title_size"], fontweight="bold")

    ax.grid(True, alpha=STYLE["grid_alpha"], color=STYLE["grid_color"])
    _save_or_show(fig, save_path)
137
+
138
+
139
def plot_segmentation_mask(
    volume: np.ndarray,
    mask: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Segmentation Mask",
    slice_idx: Optional[int] = None,
    alpha: float = 0.4
) -> None:
    """
    Plot a volume slice, its segmentation mask, and the mask overlaid on it.

    Parameters
    ----------
    volume : ndarray
        3D volume (Z, Y, X)
    mask : ndarray
        Binary (or non-negative label) mask (Z, Y, X); only the selected
        Z-slice is read. Zero is treated as background in the overlay.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    title : str, default='Segmentation Mask'
        Figure title
    slice_idx : int, optional
        Z-slice to display, None uses middle slice
    alpha : float, default=0.4
        Mask overlay transparency
    """
    if slice_idx is None:
        slice_idx = volume.shape[0] // 2

    fig, axes = plt.subplots(1, 3, figsize=(12, 4), facecolor=STYLE["facecolor"])

    vol_slice = volume[slice_idx]
    mask_slice = mask[slice_idx]

    # Contrast-stretch the volume slice to its 1st-99th percentile range.
    vmin, vmax = np.percentile(vol_slice, [1, 99])
    vol_norm = np.clip((vol_slice - vmin) / (vmax - vmin + 1e-8), 0, 1)

    # Pin the mask color scale. After masking out zeros, a binary mask's
    # overlay contains only 1s; autoscaling would then set vmin == vmax,
    # which matplotlib's Normalize maps to 0 -- the near-white end of
    # "Reds" -- so the overlay would not appear red. The same happens in the
    # mask panel for a slice that is entirely foreground.
    mask_vmin = min(float(np.min(mask_slice)), 0.0)
    mask_vmax = max(float(np.max(mask_slice)), 1.0)

    def _finish(ax, label):
        # Shared subplot title + tick removal.
        ax.set_title(label, color=STYLE["text_color"], fontsize=STYLE["label_size"], fontweight="bold")
        ax.set_xticks([])
        ax.set_yticks([])

    # Original
    _apply_dark_style(axes[0])
    axes[0].imshow(vol_norm, cmap="gray")
    _finish(axes[0], "Original")

    # Mask
    _apply_dark_style(axes[1])
    axes[1].imshow(mask_slice, cmap="Reds", vmin=mask_vmin, vmax=mask_vmax)
    _finish(axes[1], "Mask")

    # Overlay: background (0) pixels are masked out so only the mask is tinted.
    _apply_dark_style(axes[2])
    axes[2].imshow(vol_norm, cmap="gray")
    mask_overlay = np.ma.masked_where(mask_slice == 0, mask_slice)
    axes[2].imshow(mask_overlay, cmap="Reds", alpha=alpha, vmin=mask_vmin, vmax=mask_vmax)
    _finish(axes[2], "Overlay")

    fig.suptitle(f"{title} (z={slice_idx})", color=STYLE["text_color"],
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
204
+
205
+
206
def plot_registration_result(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    transformed_volume: np.ndarray,
    transform: dict,
    save_path: Optional[Union[str, Path]] = None,
    slice_idx: Optional[int] = None
) -> None:
    """
    Visualize one Z-slice of a registration before and after alignment.

    Top row: reference, moving and registered-moving slices, each displayed
    on its own 1st-99th percentile window. Bottom row: RGB composites with
    the reference in magenta and the moving slice in green, before and after
    registration, plus a text panel with the transform parameters.

    Parameters
    ----------
    ref_volume : ndarray
        Reference volume (Z, Y, X)
    moving_volume : ndarray
        Original moving volume before registration
    transformed_volume : ndarray
        Moving volume after registration
    transform : dict
        Transform parameters; the optional keys 'x_offset', 'y_offset',
        'z_offset', 'rotation', 'correlation' and 'method' are displayed.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    slice_idx : int, optional
        Z-slice to display; defaults to the middle slice of ``ref_volume``.
    """
    z = ref_volume.shape[0] // 2 if slice_idx is None else slice_idx
    text_color = STYLE["text_color"]

    fig, axes = plt.subplots(2, 3, figsize=(12, 8), facecolor=STYLE["facecolor"])

    def stretch(img):
        # Map the 1st-99th percentile range of img onto [0, 1].
        lo, hi = np.percentile(img, [1, 99])
        return np.clip((img - lo) / (hi - lo + 1e-8), 0, 1)

    def label_panel(ax, text):
        ax.set_title(text, color=text_color, fontsize=STYLE["label_size"], fontweight="bold")
        ax.set_xticks([])
        ax.set_yticks([])

    ref_slice = ref_volume[z]
    before_slice = moving_volume[z]
    after_slice = transformed_volume[z]

    # Top row: raw slices on per-slice percentile windows.
    top_panels = (
        (ref_slice, "Reference"),
        (before_slice, "Moving (before)"),
        (after_slice, "Moving (after)"),
    )
    for ax, (img, name) in zip(axes[0], top_panels):
        _apply_dark_style(ax)
        lo, hi = np.percentile(img, [1, 99])
        ax.imshow(img, cmap="gray", vmin=lo, vmax=hi)
        label_panel(ax, name)

    # Bottom row: R = B = reference (magenta), G = moving (green).
    bottom_panels = (
        (axes[1, 0], before_slice, "Before (ref=magenta, mov=green)"),
        (axes[1, 1], after_slice, "After (ref=magenta, mov=green)"),
    )
    for ax, moving, name in bottom_panels:
        _apply_dark_style(ax)
        ref_norm = stretch(ref_slice)
        composite = np.stack([ref_norm, stretch(moving), ref_norm], axis=-1)
        ax.imshow(composite)
        label_panel(ax, name)

    # Text panel with the transform parameters.
    info_ax = axes[1, 2]
    _apply_dark_style(info_ax)
    info_ax.axis("off")
    summary = "\n".join([
        "Transform Parameters",
        "─────────────────────",
        f"X offset: {transform.get('x_offset', 0):.2f} px",
        f"Y offset: {transform.get('y_offset', 0):.2f} px",
        f"Z offset: {transform.get('z_offset', 0):.2f} px",
        f"Rotation: {transform.get('rotation', 0):.2f} deg",
        f"Correlation: {transform.get('correlation', 0):.4f}",
        f"Method: {transform.get('method', 'unknown')}",
    ])
    info_ax.text(0.5, 0.5, summary, transform=info_ax.transAxes,
                 fontsize=STYLE["font_size"], color=text_color,
                 ha="center", va="center", family="monospace")

    fig.suptitle(f"Registration Result (z={z})", color=text_color,
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
301
+
302
+
303
def plot_intensity_correction(
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    corrected_volume: np.ndarray,
    correction: dict,
    save_path: Optional[Union[str, Path]] = None
) -> None:
    """
    Plot intensity correction result with histograms.

    Left: reference vs. moving histogram before correction. Middle:
    reference vs. corrected histogram. Right: text summary of the correction
    parameters. Histograms use every 100th voxel and share bin edges derived
    from the reference sample.

    Parameters
    ----------
    ref_volume : ndarray
        Reference volume
    moving_volume : ndarray
        Moving volume before correction
    corrected_volume : ndarray
        Moving volume after correction
    correction : dict
        Correction parameters; the optional keys 'factor', 'operation',
        'ref_background', 'moving_background', 'overlap_correlation' and
        'method' are displayed.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), facecolor=STYLE["facecolor"])

    # Every 100th voxel is enough for a histogram and keeps this cheap.
    ref_sample = ref_volume.ravel()[::100]
    mov_sample = moving_volume.ravel()[::100]
    corr_sample = corrected_volume.ravel()[::100]

    # Shared bin edges from the reference distribution. For non-negative data
    # the lower edge is 0 as before; it extends below 0 only when the
    # reference has negative intensities. The range is widened when it would
    # otherwise be empty or reversed (e.g. >=99.5% zero voxels, or all-negative
    # data), which previously gave degenerate or decreasing bin edges.
    lo = min(0.0, float(np.percentile(ref_sample, 0.5)))
    hi = float(np.percentile(ref_sample, 99.5))
    if hi <= lo:
        hi = lo + 1.0
    bins = np.linspace(lo, hi, 100)

    def _compare_panel(ax, other, other_color, other_label, panel_title):
        # Overlay the reference histogram with one other sample.
        _apply_dark_style(ax, fig)
        ax.hist(ref_sample, bins=bins, alpha=0.7, color="#3498db", label="Reference")
        ax.hist(other, bins=bins, alpha=0.7, color=other_color, label=other_label)
        ax.set_xlabel("Intensity", color=STYLE["text_color"], fontsize=STYLE["label_size"])
        ax.set_ylabel("Count", color=STYLE["text_color"], fontsize=STYLE["label_size"])
        ax.set_title(panel_title, color=STYLE["text_color"],
                     fontsize=STYLE["label_size"], fontweight="bold")
        ax.legend(facecolor=STYLE["facecolor"], edgecolor=STYLE["text_color"],
                  labelcolor=STYLE["text_color"], fontsize=8)

    _compare_panel(axes[0], mov_sample, "#e74c3c", "Moving", "Before Correction")
    _compare_panel(axes[1], corr_sample, "#2ecc71", "Corrected", "After Correction")

    # Correction info
    _apply_dark_style(axes[2], fig)
    axes[2].axis("off")
    info_text = (
        f"Intensity Correction\n"
        f"─────────────────────\n"
        f"Factor: {correction.get('factor', 1.0):.4f}\n"
        f"Operation: {correction.get('operation', 'multiply')}\n"
        f"Ref background: {correction.get('ref_background', 0):.1f}\n"
        f"Mov background: {correction.get('moving_background', 0):.1f}\n"
        f"Overlap corr: {correction.get('overlap_correlation', 0):.4f}\n"
        f"Method: {correction.get('method', 'unknown')}"
    )
    axes[2].text(0.5, 0.5, info_text, transform=axes[2].transAxes,
                 fontsize=STYLE["font_size"], color=STYLE["text_color"],
                 ha="center", va="center", family="monospace")

    fig.suptitle("Intensity Correction", color=STYLE["text_color"],
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
376
+
377
+
378
def plot_fusion_result(
    view1: np.ndarray,
    view2: np.ndarray,
    fused: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    slice_idx: Optional[int] = None,
    title: str = "Fusion Result"
) -> None:
    """
    Compare one Z-slice of two input views with the fused output.

    The first three panels show view 1, view 2 and the fused slice on their
    own 1st-99th percentile windows. The fourth panel is an RGB composite of
    the min-max scaled slices: red = view 1, green = fused, blue = view 2.

    Parameters
    ----------
    view1 : ndarray
        First input view (Z, Y, X)
    view2 : ndarray
        Second input view (Z, Y, X)
    fused : ndarray
        Fused output (Z, Y, X)
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    slice_idx : int, optional
        Z-slice to display; defaults to the middle slice of ``view1``.
    title : str, default='Fusion Result'
        Figure title
    """
    z = view1.shape[0] // 2 if slice_idx is None else slice_idx
    fg = STYLE["text_color"]

    fig, axes = plt.subplots(2, 2, figsize=(10, 10), facecolor=STYLE["facecolor"])

    v1_slice, v2_slice, fused_slice = view1[z], view2[z], fused[z]

    def _rescale(img):
        # Min-max scale to [0, 1]; the epsilon guards against flat images.
        return (img - img.min()) / (img.max() - img.min() + 1e-8)

    def _label(ax, text):
        ax.set_title(text, color=fg, fontsize=STYLE["label_size"], fontweight="bold")
        ax.set_xticks([])
        ax.set_yticks([])

    # Individual panels on per-slice percentile windows.
    singles = zip(axes.flat[:3],
                  (v1_slice, v2_slice, fused_slice),
                  ("View 1", "View 2", "Fused"))
    for ax, img, name in singles:
        _apply_dark_style(ax)
        lo, hi = np.percentile(img, [1, 99])
        ax.imshow(img, cmap="gray", vmin=lo, vmax=hi)
        _label(ax, name)

    # Composite panel: which view dominates shows up as red vs. blue tint.
    comp_ax = axes[1, 1]
    _apply_dark_style(comp_ax)
    composite = np.stack(
        [_rescale(v1_slice), _rescale(fused_slice), _rescale(v2_slice)], axis=-1
    )
    comp_ax.imshow(composite)
    _label(comp_ax, "Comparison (R=V1, G=fused, B=V2)")

    fig.suptitle(f"{title} (z={z})", color=fg,
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
438
+
439
+
440
def plot_blending_weights(
    mask: np.ndarray,
    blending_range: int,
    save_path: Optional[Union[str, Path]] = None,
    slice_idx: Optional[int] = None
) -> None:
    """
    Plot blending weight map from transition mask.

    Left panel: the transition plane ``mask.T`` as an image (X horizontal,
    Z vertical) with a colorbar. Right panel: the transition Y-index along X
    for one Z column, shaded by ``±blending_range``.

    Parameters
    ----------
    mask : ndarray
        2D transition mask indexed (X, Z) whose values are Y-indices.
    blending_range : int
        Blending zone width in pixels; the profile plot shades
        ``profile ± blending_range``.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    slice_idx : int, optional
        Z column (``mask[:, slice_idx]``) for the 1D profile; defaults to the
        middle column.
    """
    if slice_idx is None:
        slice_idx = mask.shape[1] // 2

    text_color = STYLE["text_color"]
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), facecolor=STYLE["facecolor"])

    # 2D transition plane visualization
    _apply_dark_style(axes[0], fig)
    im = axes[0].imshow(mask.T, cmap="viridis", aspect="auto")
    axes[0].set_xlabel("X", color=text_color, fontsize=STYLE["label_size"])
    axes[0].set_ylabel("Z", color=text_color, fontsize=STYLE["label_size"])
    axes[0].set_title("Transition Plane (Y-index)", color=text_color,
                      fontsize=STYLE["label_size"], fontweight="bold")
    cbar = fig.colorbar(im, ax=axes[0])
    # Color ticks and tick labels in place. Re-setting the labels via
    # set_ticklabels(get_ticklabels(), ...) froze them with a FixedFormatter
    # and triggers a matplotlib warning.
    cbar.ax.tick_params(axis="y", colors=text_color)

    # 1D profile at selected z. Cast to float: if the mask has an unsigned
    # integer dtype, ``profile - blending_range`` would wrap around near 0.
    _apply_dark_style(axes[1], fig)
    profile = np.asarray(mask[:, slice_idx], dtype=float)
    x = np.arange(len(profile))
    axes[1].plot(x, profile, color="#3498db", linewidth=2, label="Transition Y")
    axes[1].fill_between(x, profile - blending_range, profile + blending_range,
                         alpha=0.3, color="#3498db", label=f"Blend zone (±{blending_range})")
    axes[1].set_xlabel("X", color=text_color, fontsize=STYLE["label_size"])
    axes[1].set_ylabel("Y (transition)", color=text_color, fontsize=STYLE["label_size"])
    axes[1].set_title(f"Profile at z={slice_idx}", color=text_color,
                      fontsize=STYLE["label_size"], fontweight="bold")
    axes[1].legend(facecolor=STYLE["facecolor"], edgecolor=text_color,
                   labelcolor=text_color, fontsize=8)
    axes[1].grid(True, alpha=STYLE["grid_alpha"], color=STYLE["grid_color"])

    fig.suptitle("Blending Weights", color=text_color,
                 fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
495
+
496
+
497
def plot_temporal_parameters(
    timepoints: np.ndarray,
    parameters: dict,
    smoothed_parameters: Optional[dict] = None,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Temporal Parameters"
) -> None:
    """
    Plot registration parameters over time.

    Draws a 2x2 grid for the keys 'x_offset', 'y_offset', 'rotation' and
    'intensity_factor'. A panel stays empty when its key is missing from
    both dicts.

    Parameters
    ----------
    timepoints : ndarray
        Time point indices (x-axis for every panel)
    parameters : dict
        Raw per-timepoint parameter series, keyed by parameter name.
    smoothed_parameters : dict, optional
        Smoothed series with the same keys, drawn as a thicker line.
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    title : str, default='Temporal Parameters'
        Figure title
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), facecolor=STYLE["facecolor"])

    param_names = ["x_offset", "y_offset", "rotation", "intensity_factor"]
    ylabels = ["X Offset (px)", "Y Offset (px)", "Rotation (deg)", "Intensity Factor"]
    colors = ["#3498db", "#2ecc71", "#e74c3c", "#f39c12"]

    for ax, param, ylabel, color in zip(axes.flat, param_names, ylabels, colors):
        _apply_dark_style(ax, fig)

        if param in parameters:
            ax.plot(timepoints, parameters[param], "o-", color=color, alpha=0.6,
                    markersize=3, linewidth=1, label="Raw")

        if smoothed_parameters and param in smoothed_parameters:
            ax.plot(timepoints, smoothed_parameters[param], "-", color=color,
                    linewidth=2, label="Smoothed")

        ax.set_xlabel("Timepoint", color=STYLE["text_color"], fontsize=STYLE["label_size"])
        ax.set_ylabel(ylabel, color=STYLE["text_color"], fontsize=STYLE["label_size"])
        ax.grid(True, alpha=STYLE["grid_alpha"], color=STYLE["grid_color"])

        # Only add a legend when this panel actually has labeled lines;
        # legend() on an empty panel warns ("No artists with labels found")
        # and draws an empty legend box.
        if smoothed_parameters and ax.get_legend_handles_labels()[0]:
            ax.legend(facecolor=STYLE["facecolor"], edgecolor=STYLE["text_color"],
                      labelcolor=STYLE["text_color"], fontsize=8)

    fig.suptitle(title, color=STYLE["text_color"], fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
548
+
549
+
550
def plot_correlation_heatmap(
    correlations: np.ndarray,
    timepoints: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Registration Correlation"
) -> None:
    """
    Plot registration correlation quality over time.

    Left: correlation per timepoint with a dashed mean line. Right:
    histogram of the correlations with the same mean marked. Non-finite
    values (NaN/inf) are drawn as gaps in the line plot but excluded from
    the mean and the histogram.

    Parameters
    ----------
    correlations : ndarray
        Correlation values per timepoint
    timepoints : ndarray
        Time point indices
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    title : str, default='Registration Correlation'
        Figure title
    """
    correlations = np.asarray(correlations, dtype=float)
    # A single NaN would make np.mean NaN, and np.histogram cannot
    # autodetect a bin range from NaN data, so statistics use finite values.
    finite = correlations[np.isfinite(correlations)]
    has_data = finite.size > 0
    mean_corr = float(np.mean(finite)) if has_data else float("nan")
    mean_label = f"Mean: {mean_corr:.4f}"

    legend_kwargs = dict(facecolor=STYLE["facecolor"], edgecolor=STYLE["text_color"],
                         labelcolor=STYLE["text_color"], fontsize=8)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4), facecolor=STYLE["facecolor"])

    # Line plot
    _apply_dark_style(axes[0], fig)
    axes[0].plot(timepoints, correlations, "o-", color="#3498db", markersize=4, linewidth=1.5)
    if has_data:
        axes[0].axhline(mean_corr, color="#e74c3c", linestyle="--",
                        linewidth=1.5, label=mean_label)
    axes[0].set_xlabel("Timepoint", color=STYLE["text_color"], fontsize=STYLE["label_size"])
    axes[0].set_ylabel("Correlation", color=STYLE["text_color"], fontsize=STYLE["label_size"])
    axes[0].set_title("Correlation Over Time", color=STYLE["text_color"],
                      fontsize=STYLE["label_size"], fontweight="bold")
    axes[0].grid(True, alpha=STYLE["grid_alpha"], color=STYLE["grid_color"])
    if has_data:
        axes[0].legend(**legend_kwargs)

    # Histogram
    _apply_dark_style(axes[1], fig)
    axes[1].hist(finite, bins=30, color="#3498db", alpha=0.8, edgecolor="none")
    if has_data:
        axes[1].axvline(mean_corr, color="#e74c3c", linestyle="--",
                        linewidth=1.5, label=mean_label)
    axes[1].set_xlabel("Correlation", color=STYLE["text_color"], fontsize=STYLE["label_size"])
    axes[1].set_ylabel("Count", color=STYLE["text_color"], fontsize=STYLE["label_size"])
    axes[1].set_title("Correlation Distribution", color=STYLE["text_color"],
                      fontsize=STYLE["label_size"], fontweight="bold")
    if has_data:
        axes[1].legend(**legend_kwargs)

    fig.suptitle(title, color=STYLE["text_color"], fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
600
+
601
+
602
def plot_volume_overview(
    volume: np.ndarray,
    save_path: Optional[Union[str, Path]] = None,
    title: str = "Volume Overview",
    n_slices: int = 9
) -> None:
    """
    Plot grid of evenly spaced Z-slices through volume.

    All panels share one gray-scale window: the 1st-99th percentile of the
    whole volume.

    Parameters
    ----------
    volume : ndarray
        3D volume (Z, Y, X)
    save_path : str or Path, optional
        Path to save figure; when omitted the figure is shown.
    title : str, default='Volume Overview'
        Figure title
    n_slices : int, default=9
        Number of slices to display (>= 1). The grid has
        ``floor(sqrt(n_slices))`` rows and enough columns to fit every
        slice; unused cells are hidden. If ``n_slices`` exceeds the number
        of Z-planes, some planes are shown more than once.

    Raises
    ------
    ValueError
        If ``n_slices`` is less than 1.
    """
    if n_slices < 1:
        raise ValueError(f"n_slices must be >= 1, got {n_slices}")

    n_rows = int(np.sqrt(n_slices))
    n_cols = int(np.ceil(n_slices / n_rows))

    z_count = volume.shape[0]
    slice_indices = np.linspace(0, z_count - 1, n_slices, dtype=int)

    # squeeze=False keeps ``axes`` a 2D array even for a 1x1 grid; with the
    # default a single Axes is returned and ``axes.flat`` would fail.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows),
                             facecolor=STYLE["facecolor"], squeeze=False)

    vmin, vmax = np.percentile(volume, [1, 99])

    for ax, z in zip(axes.flat, slice_indices):
        _apply_dark_style(ax)
        ax.imshow(volume[z], cmap="gray", vmin=vmin, vmax=vmax)
        ax.set_title(f"z={z}", color=STYLE["text_color"], fontsize=STYLE["font_size"])
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide grid cells beyond the requested number of slices.
    for ax in axes.flat[n_slices:]:
        ax.axis("off")

    fig.suptitle(title, color=STYLE["text_color"], fontsize=STYLE["title_size"], fontweight="bold")
    plt.tight_layout()
    _save_or_show(fig, save_path)
647
+
648
+
649
def save_fusion_diagnostics(
    output_dir: Union[str, Path],
    ref_volume: np.ndarray,
    moving_volume: np.ndarray,
    transformed_volume: np.ndarray,
    fused_volume: np.ndarray,
    transform: dict,
    correction: dict,
    mask: np.ndarray,
    blending_range: int,
    prefix: str = "",
    corrected_volume: Optional[np.ndarray] = None,
) -> None:
    """
    Save complete set of fusion diagnostic images.

    Writes five PNGs into ``output_dir`` (created if missing), each name
    prefixed with ``"{prefix}_"`` when ``prefix`` is non-empty:
    ``registration.png``, ``intensity_correction.png``,
    ``fusion_result.png``, ``blending_weights.png`` and
    ``fused_overview.png``.

    Parameters
    ----------
    output_dir : str or Path
        Directory to save images
    ref_volume : ndarray
        Reference volume
    moving_volume : ndarray
        Original moving volume
    transformed_volume : ndarray
        Transformed moving volume
    fused_volume : ndarray
        Final fused result
    transform : dict
        Registration transform parameters
    correction : dict
        Intensity correction parameters
    mask : ndarray
        Blending transition mask
    blending_range : int
        Blending zone width
    prefix : str, default=''
        Prefix for filenames
    corrected_volume : ndarray, optional
        Transformed moving volume with the intensity correction applied.
        When omitted, ``transformed_volume`` is plotted in its place, so the
        "After Correction" histogram shows the uncorrected data.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    pre = f"{prefix}_" if prefix else ""

    # Registration
    plot_registration_result(
        ref_volume, moving_volume, transformed_volume, transform,
        save_path=output_dir / f"{pre}registration.png"
    )

    # Intensity correction. No copy is needed when falling back to the
    # transformed volume: the plot only reads the data, and copying a full
    # 3D volume just doubles peak memory.
    if corrected_volume is None:
        corrected_volume = transformed_volume
    plot_intensity_correction(
        ref_volume, transformed_volume, corrected_volume, correction,
        save_path=output_dir / f"{pre}intensity_correction.png"
    )

    # Fusion result
    plot_fusion_result(
        ref_volume, transformed_volume, fused_volume,
        save_path=output_dir / f"{pre}fusion_result.png"
    )

    # Blending weights
    plot_blending_weights(
        mask, blending_range,
        save_path=output_dir / f"{pre}blending_weights.png"
    )

    # Volume overview
    plot_volume_overview(
        fused_volume,
        save_path=output_dir / f"{pre}fused_overview.png"
    )

    print(f"Saved fusion diagnostics to {output_dir}")