pytractoviz 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pytractoviz/viz.py ADDED
@@ -0,0 +1,4272 @@
1
+ """Visualization module for diffusion tractography."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextlib
6
+ import gc
7
+ import logging
8
+ import multiprocessing
9
+ import os
10
+ import resource
11
+ import sys
12
+ import tempfile
13
+ import tracemalloc
14
+ from concurrent.futures import ProcessPoolExecutor, as_completed
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ try:
19
+ import psutil
20
+ except ImportError:
21
+ psutil = None
22
+
23
+ import imageio
24
+ import matplotlib.pyplot as plt
25
+ import nibabel as nib
26
+ import numpy as np
27
+ import vtk
28
+ from dipy.io.stateful_tractogram import Space, StatefulTractogram
29
+ from dipy.io.streamline import load_trk
30
+ from dipy.segment.bundles import bundle_shape_similarity
31
+ from dipy.segment.clustering import QuickBundles
32
+ from dipy.segment.featurespeed import ResampleFeature
33
+ from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
34
+ from dipy.stats.analysis import afq_profile, assignment_map, gaussian_weights
35
+ from dipy.tracking.streamline import (
36
+ Streamlines,
37
+ cluster_confidence,
38
+ orient_by_streamline,
39
+ transform_streamlines,
40
+ )
41
+ from dipy.tracking.utils import length
42
+ from fury import actor, window
43
+ from fury.colormap import create_colormap
44
+ from PIL import Image
45
+ from scipy.spatial.transform import Rotation
46
+ from xvfbwrapper import Xvfb
47
+
48
+ from pytractoviz.html import create_quality_check_html
49
+ from pytractoviz.utils import (
50
+ ANATOMICAL_VIEW_ANGLES,
51
+ calculate_bbox_size,
52
+ calculate_centroid,
53
+ calculate_combined_bbox_size,
54
+ calculate_combined_centroid,
55
+ calculate_direction_colors,
56
+ set_anatomical_camera,
57
+ )
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ # Constants for memory estimation and warnings
62
+ MAX_POINTS_PER_STREAMLINE_FRAGMENTATION_THRESHOLD = 10000 # Streamlines with >10k points risk fragmentation
63
+ DENSE_TRACT_POINT_THRESHOLD = 1000000 # >1M total points triggers dense tract warning
64
+ VERY_LONG_STREAMLINE_POINT_THRESHOLD = 50000 # Streamlines with >50k points trigger warning
65
+
66
+
67
+ def _log_memory_usage(
68
+ label: str = "",
69
+ *,
70
+ enable_tracemalloc: bool = False,
71
+ log_level: int = logging.DEBUG,
72
+ ) -> dict[str, float | int] | None:
73
+ """Log current memory usage for debugging memory issues.
74
+
75
+ This function provides comprehensive memory monitoring using:
76
+ - psutil (if available): Process-level memory (RSS, VMS)
77
+ - tracemalloc (built-in): Python-level memory tracking
78
+
79
+ Parameters
80
+ ----------
81
+ label : str, optional
82
+ Label to identify this memory checkpoint (e.g., "after loading tract").
83
+ enable_tracemalloc : bool, default=False
84
+ If True, also log top memory allocations from tracemalloc.
85
+ Note: tracemalloc must be started with tracemalloc.start() first.
86
+ log_level : int, default=logging.DEBUG
87
+ Logging level to use for memory information.
88
+
89
+ Returns
90
+ -------
91
+ dict[str, float | int] | None
92
+ Dictionary with memory statistics, or None if psutil is not available.
93
+ Keys: 'rss_mb', 'vms_mb', 'percent', 'available_mb' (if psutil available).
94
+
95
+ Examples
96
+ --------
97
+ >>> # Basic usage
98
+ >>> _log_memory_usage("Before processing")
99
+ >>> # Process data...
100
+ >>> _log_memory_usage("After processing")
101
+
102
+ >>> # With tracemalloc for detailed tracking
103
+ >>> import tracemalloc
104
+ >>> tracemalloc.start()
105
+ >>> _log_memory_usage("Checkpoint 1", enable_tracemalloc=True)
106
+ """
107
+ memory_info: dict[str, float | int] = {}
108
+
109
+ # Method 1: resource module (built-in, Unix/macOS/Linux, no dependencies)
110
+ try:
111
+ usage = resource.getrusage(resource.RUSAGE_SELF)
112
+ # ru_maxrss is in KB on macOS, MB on Linux
113
+ if sys.platform == "darwin": # macOS
114
+ resource_rss_mb = usage.ru_maxrss / 1024
115
+ else: # Linux
116
+ resource_rss_mb = usage.ru_maxrss
117
+ memory_info["resource_rss_mb"] = round(resource_rss_mb, 2)
118
+ logger.log(
119
+ log_level,
120
+ "Memory usage%s (resource): RSS=%.2f MB",
121
+ f" [{label}]" if label else "",
122
+ resource_rss_mb,
123
+ )
124
+ except (OSError, AttributeError):
125
+ # resource module not available on this platform
126
+ pass
127
+
128
+ # Method 2: Process-level memory using psutil (if available, more detailed)
129
+ if psutil is not None:
130
+ process = psutil.Process(os.getpid())
131
+ mem_info = process.memory_info()
132
+ mem_percent = process.memory_percent()
133
+
134
+ # Get system memory info
135
+ try:
136
+ sys_mem = psutil.virtual_memory()
137
+ available_mb = sys_mem.available / (1024**2)
138
+ except (OSError, RuntimeError, AttributeError):
139
+ available_mb = 0.0
140
+
141
+ rss_mb = mem_info.rss / (1024**2) # Resident Set Size
142
+ vms_mb = mem_info.vms / (1024**2) # Virtual Memory Size
143
+
144
+ memory_info = {
145
+ "rss_mb": round(rss_mb, 2),
146
+ "vms_mb": round(vms_mb, 2),
147
+ "percent": round(mem_percent, 2),
148
+ "available_mb": round(available_mb, 2),
149
+ }
150
+
151
+ logger.log(
152
+ log_level,
153
+ "Memory usage%s: RSS=%.2f MB, VMS=%.2f MB, Process=%.2f%%, System Available=%.2f MB",
154
+ f" [{label}]" if label else "",
155
+ rss_mb,
156
+ vms_mb,
157
+ mem_percent,
158
+ available_mb,
159
+ )
160
+ else:
161
+ logger.log(
162
+ log_level,
163
+ "Memory monitoring: psutil not available. Install with: pip install psutil",
164
+ )
165
+
166
+ # Python-level memory using tracemalloc (if enabled)
167
+ if enable_tracemalloc and tracemalloc.is_tracing():
168
+ snapshot = tracemalloc.take_snapshot()
169
+ top_stats = snapshot.statistics("lineno")
170
+
171
+ logger.log(log_level, "Top 5 memory allocations%s:", f" [{label}]" if label else "")
172
+ for index, stat in enumerate(top_stats[:5], 1):
173
+ logger.log(
174
+ log_level,
175
+ " #%d: %s: %.1f MB",
176
+ index,
177
+ stat.traceback[0],
178
+ stat.size / (1024**2),
179
+ )
180
+
181
+ # Get current and peak memory
182
+ current, peak = tracemalloc.get_traced_memory()
183
+ logger.log(
184
+ log_level,
185
+ "Tracemalloc: Current=%.2f MB, Peak=%.2f MB",
186
+ current / (1024**2),
187
+ peak / (1024**2),
188
+ )
189
+
190
+ memory_info["tracemalloc_current_mb"] = round(current / (1024**2), 2)
191
+ memory_info["tracemalloc_peak_mb"] = round(peak / (1024**2), 2)
192
+
193
+ return memory_info if memory_info else None
194
+
195
+
196
+ def _set_memory_limit(memory_limit_mb: float | None = None) -> None:
197
+ """Set a hard memory limit for the current process.
198
+
199
+ This uses resource.setrlimit() to set a maximum virtual memory (address space)
200
+ limit. If the process exceeds this limit, it will be killed by the OS.
201
+
202
+ Parameters
203
+ ----------
204
+ memory_limit_mb : float | None, optional
205
+ Maximum memory limit in MB. If None, no limit is set.
206
+ If set, the process will be killed if it exceeds this limit.
207
+
208
+ Examples
209
+ --------
210
+ >>> # Limit to 8 GB
211
+ >>> _set_memory_limit(8192)
212
+
213
+ >>> # Limit to 4 GB
214
+ >>> _set_memory_limit(4096)
215
+
216
+ Notes
217
+ -----
218
+ - This sets RLIMIT_AS (virtual memory/address space limit)
219
+ - On macOS, this is enforced by the kernel
220
+ - The limit applies to the entire process, including all threads
221
+ - Once set, the limit cannot be increased (only decreased)
222
+ """
223
+ if memory_limit_mb is None:
224
+ return
225
+
226
+ try:
227
+ # Convert MB to bytes
228
+ memory_limit_bytes = int(memory_limit_mb * 1024 * 1024)
229
+
230
+ # Set virtual memory limit (RLIMIT_AS)
231
+ # This limits the total address space the process can use
232
+ resource.setrlimit(resource.RLIMIT_AS, (memory_limit_bytes, resource.RLIM_INFINITY))
233
+
234
+ logger.info("Memory limit set to %.2f MB (%.2f GB)", memory_limit_mb, memory_limit_mb / 1024)
235
+ except (OSError, ValueError) as e:
236
+ logger.warning("Failed to set memory limit: %s", e)
237
+
238
+
239
+ def _estimate_actor_memory_mb(streamlines: Streamlines, figure_size: tuple[int, int] = (800, 800)) -> float:
240
+ """Estimate memory needed for creating VTK actors and rendering.
241
+
242
+ Parameters
243
+ ----------
244
+ streamlines : Streamlines
245
+ The streamlines to visualize.
246
+ figure_size : tuple[int, int], default=(800, 800)
247
+ Size of output image in pixels.
248
+
249
+ Returns
250
+ -------
251
+ float
252
+ Estimated memory in MB needed for actor creation and rendering.
253
+ """
254
+ # Count total points
255
+ total_points = sum(len(sl) for sl in streamlines)
256
+ num_streamlines = len(streamlines)
257
+
258
+ # Calculate average points per streamline
259
+ avg_points_per_sl = total_points / num_streamlines if num_streamlines > 0 else 0
260
+ max_points_per_sl = max(len(sl) for sl in streamlines) if streamlines else 0
261
+
262
+ # Estimate memory:
263
+ # - VTK actor data: ~200 bytes per point (coordinates, colors, normals, etc.)
264
+ # (increased from 100 to account for VTK's internal structures)
265
+ # - Scene rendering buffer: image_size * 4 bytes (RGBA) * 3 (triple buffering for safety)
266
+ # - Additional overhead: ~30% for VTK internal structures and fragmentation
267
+ # - Large contiguous allocation penalty: if max streamline is very long, VTK may need
268
+ # a large contiguous buffer, which can fail even with available memory due to fragmentation
269
+
270
+ actor_memory_mb = (total_points * 200) / (1024**2) # Actor data (increased estimate)
271
+ render_memory_mb = (figure_size[0] * figure_size[1] * 4 * 3) / (1024**2) # Render buffer (triple buffering)
272
+
273
+ # Add penalty for very long streamlines (fragmentation risk)
274
+ # If any streamline has >10k points, VTK needs a large contiguous buffer
275
+ fragmentation_penalty_mb = 0.0
276
+ if max_points_per_sl > MAX_POINTS_PER_STREAMLINE_FRAGMENTATION_THRESHOLD:
277
+ # Large contiguous allocation needed - add 50% overhead for fragmentation
278
+ fragmentation_penalty_mb = actor_memory_mb * 0.5
279
+ logger.warning(
280
+ "Very long streamline detected (%d points). VTK may need large contiguous memory allocation.",
281
+ max_points_per_sl,
282
+ )
283
+
284
+ overhead_mb = (actor_memory_mb + render_memory_mb) * 0.3 # 30% overhead
285
+ total_mb = actor_memory_mb + render_memory_mb + overhead_mb + fragmentation_penalty_mb
286
+
287
+ logger.info(
288
+ "Memory estimate for actor: %.2f MB (total_points=%d, streamlines=%d, "
289
+ "avg_points/sl=%.1f, max_points/sl=%d, image=%dx%d)",
290
+ total_mb,
291
+ total_points,
292
+ num_streamlines,
293
+ avg_points_per_sl,
294
+ max_points_per_sl,
295
+ figure_size[0],
296
+ figure_size[1],
297
+ )
298
+
299
+ return total_mb
300
+
301
+
302
+ def _check_memory_available(required_mb: float, safety_margin: float = 0.2) -> bool:
303
+ """Check if enough memory is available before loading data.
304
+
305
+ Parameters
306
+ ----------
307
+ required_mb : float
308
+ Estimated memory required in MB.
309
+ safety_margin : float, default=0.2
310
+ Safety margin as fraction (0.2 = 20% extra buffer).
311
+
312
+ Returns
313
+ -------
314
+ bool
315
+ True if enough memory is available, False otherwise.
316
+ """
317
+ if psutil is None:
318
+ # If psutil not available, assume we have enough
319
+ logger.debug("psutil not available, skipping memory check")
320
+ return True
321
+
322
+ try:
323
+ # Get system memory
324
+ sys_mem = psutil.virtual_memory()
325
+ available_mb = sys_mem.available / (1024**2)
326
+
327
+ # Get current process memory
328
+ process = psutil.Process(os.getpid())
329
+ current_mb = process.memory_info().rss / (1024**2)
330
+
331
+ # Calculate required with safety margin
332
+ required_with_margin = required_mb * (1 + safety_margin)
333
+
334
+ # Check if we have enough available memory
335
+ if available_mb < required_with_margin:
336
+ logger.warning(
337
+ "Insufficient memory: Need %.2f MB (with %.0f%% margin), "
338
+ "but only %.2f MB available. Current process: %.2f MB",
339
+ required_with_margin,
340
+ safety_margin * 100,
341
+ available_mb,
342
+ current_mb,
343
+ )
344
+ return False
345
+
346
+ logger.debug(
347
+ "Memory check: Need %.2f MB, Available: %.2f MB, Current: %.2f MB",
348
+ required_with_margin,
349
+ available_mb,
350
+ current_mb,
351
+ )
352
+ except (OSError, RuntimeError, AttributeError) as e:
353
+ logger.warning("Memory check failed: %s. Proceeding anyway.", e)
354
+ return True
355
+ else:
356
+ return True
357
+
358
+
359
+ def _get_n_jobs_with_memory_limit(
360
+ base_n_jobs: int,
361
+ estimated_memory_per_job_mb: float = 2000.0,
362
+ safety_margin: float = 0.2,
363
+ ) -> int:
364
+ """Calculate n_jobs considering available memory.
365
+
366
+ Reduces n_jobs if there isn't enough memory to run all jobs in parallel.
367
+
368
+ Parameters
369
+ ----------
370
+ base_n_jobs : int
371
+ Base number of jobs to use (from CPU/SLURM settings).
372
+ estimated_memory_per_job_mb : float, default=2000.0
373
+ Estimated memory per job in MB. Default assumes ~2GB per worker.
374
+ safety_margin : float, default=0.2
375
+ Safety margin as fraction (0.2 = 20% extra buffer).
376
+
377
+ Returns
378
+ -------
379
+ int
380
+ Adjusted number of jobs considering memory constraints.
381
+ """
382
+ if psutil is None:
383
+ # If psutil not available, return base_n_jobs
384
+ return base_n_jobs
385
+
386
+ try:
387
+ # Get available system memory
388
+ sys_mem = psutil.virtual_memory()
389
+ available_mb = sys_mem.available / (1024**2)
390
+
391
+ # Get current process memory
392
+ process = psutil.Process(os.getpid())
393
+ current_mb = process.memory_info().rss / (1024**2)
394
+
395
+ # Calculate how much memory we can use for workers
396
+ # Reserve some for the main process
397
+ usable_mb = available_mb - (current_mb * 0.5) # Reserve 50% of current for main process
398
+
399
+ # Calculate required memory with safety margin
400
+ required_per_job = estimated_memory_per_job_mb * (1 + safety_margin)
401
+
402
+ # Calculate max jobs based on memory
403
+ max_jobs_by_memory = max(1, int(usable_mb / required_per_job))
404
+
405
+ # Use the minimum of base_n_jobs and memory-limited jobs
406
+ optimal_jobs = min(base_n_jobs, max_jobs_by_memory)
407
+
408
+ if optimal_jobs < base_n_jobs:
409
+ logger.info(
410
+ "Reduced n_jobs from %d to %d due to memory constraints "
411
+ "(Available: %.2f MB, Estimated per job: %.2f MB)",
412
+ base_n_jobs,
413
+ optimal_jobs,
414
+ usable_mb,
415
+ required_per_job,
416
+ )
417
+ except (OSError, RuntimeError, AttributeError) as e:
418
+ logger.warning("Memory-based n_jobs calculation failed: %s. Using base_n_jobs=%d", e, base_n_jobs)
419
+ return base_n_jobs
420
+ else:
421
+ return optimal_jobs
422
+
423
+
424
+ def _get_optimal_n_jobs() -> int:
425
+ """Calculate optimal number of jobs considering SLURM and OpenMP settings.
426
+
427
+ This function respects SLURM CPU allocations and OpenMP thread settings
428
+ to prevent oversubscription. It checks:
429
+ 1. SLURM_CPUS_PER_TASK (if running under SLURM)
430
+ 2. SLURM_JOB_CPUS_PER_NODE (if running under SLURM)
431
+ 3. OMP_NUM_THREADS (to avoid conflicts with OpenMP)
432
+ 4. Falls back to multiprocessing.cpu_count() if not in SLURM
433
+
434
+ Returns
435
+ -------
436
+ int
437
+ Optimal number of parallel jobs to use.
438
+ """
439
+ # Check if running under SLURM
440
+ slurm_cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK")
441
+ slurm_job_cpus = os.environ.get("SLURM_JOB_CPUS_PER_NODE")
442
+
443
+ if slurm_cpus_per_task:
444
+ # Use SLURM allocation
445
+ allocated_cpus = int(slurm_cpus_per_task)
446
+ logger.debug("Using SLURM_CPUS_PER_TASK=%d", allocated_cpus)
447
+
448
+ # Check OMP_NUM_THREADS to avoid oversubscription
449
+ omp_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
450
+ if omp_threads > 1:
451
+ # If OpenMP uses multiple threads, reduce n_jobs accordingly
452
+ # Formula: n_jobs = allocated_cpus / omp_threads
453
+ optimal_jobs = max(1, allocated_cpus // omp_threads)
454
+ logger.debug(
455
+ "OMP_NUM_THREADS=%d detected. Using n_jobs=%d (allocated_cpus=%d / omp_threads=%d)",
456
+ omp_threads,
457
+ optimal_jobs,
458
+ allocated_cpus,
459
+ omp_threads,
460
+ )
461
+ return optimal_jobs
462
+ return allocated_cpus
463
+ if slurm_job_cpus:
464
+ # Use SLURM job CPUs (may be a range like "8-16", take the first value)
465
+ allocated_cpus = int(slurm_job_cpus.split("-")[0].split(",")[0])
466
+ logger.debug("Using SLURM_JOB_CPUS_PER_NODE=%d", allocated_cpus)
467
+
468
+ # Check OMP_NUM_THREADS
469
+ omp_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
470
+ if omp_threads > 1:
471
+ optimal_jobs = max(1, allocated_cpus // omp_threads)
472
+ logger.debug(
473
+ "OMP_NUM_THREADS=%d detected. Using n_jobs=%d",
474
+ omp_threads,
475
+ optimal_jobs,
476
+ )
477
+ return optimal_jobs
478
+ return allocated_cpus
479
+
480
+ # Not in SLURM, check OMP_NUM_THREADS and fall back to cpu_count()
481
+ omp_threads = int(os.environ.get("OMP_NUM_THREADS", "0"))
482
+ if omp_threads > 0:
483
+ # If OMP_NUM_THREADS is set, use it to calculate n_jobs
484
+ total_cpus = multiprocessing.cpu_count()
485
+ optimal_jobs = max(1, total_cpus // omp_threads)
486
+ logger.debug(
487
+ "OMP_NUM_THREADS=%d detected. Using n_jobs=%d (cpu_count=%d / omp_threads=%d)",
488
+ omp_threads,
489
+ optimal_jobs,
490
+ total_cpus,
491
+ omp_threads,
492
+ )
493
+ return optimal_jobs
494
+
495
+ # Default: use all available CPUs
496
+ total_cpus = multiprocessing.cpu_count()
497
+ logger.debug("Using all available CPUs: n_jobs=%d", total_cpus)
498
+ return total_cpus
499
+
500
+
501
+ def _process_tract_worker(
502
+ subject_id: str,
503
+ tract_name: str,
504
+ tract_file: str | Path,
505
+ subject_ref_img: str | Path,
506
+ tract_output_dir: str | Path,
507
+ subjects_mni_space: dict[str, dict[str, str | Path]] | None,
508
+ atlas_files: dict[str, str | Path] | None,
509
+ metric_files: dict[str, dict[str, str | Path]] | None,
510
+ atlas_ref_img: str | Path | None,
511
+ *,
512
+ flip_lr: bool,
513
+ skip_checks: list[str],
514
+ visualizer_params: dict[str, Any],
515
+ **kwargs: Any,
516
+ ) -> tuple[str, str, dict[str, str | Path]]:
517
+ """Worker function for parallel processing of tracts.
518
+
519
+ This function creates a new visualizer instance and processes a single tract.
520
+ Returns (subject_id, tract_name, results_dict).
521
+ """
522
+ visualizer = None
523
+ results: dict[str, str | Path] = {}
524
+
525
+ try:
526
+ # Initialize VTK offscreen rendering in worker process
527
+ # This is critical for headless cluster environments
528
+ # Set environment variables before any VTK operations
529
+ os.environ.setdefault("VTK_STREAM_READER", "1")
530
+ os.environ.setdefault("VTK_STREAM_WRITER", "1")
531
+ # Force offscreen rendering to prevent segfaults
532
+ os.environ.setdefault("VTK_USE_OFFSCREEN", "1")
533
+ # Disable OpenGL to prevent segfaults in headless environments
534
+ os.environ.setdefault("VTK_USE_OSMESA", "0")
535
+
536
+ # Disable VTK warnings and errors that might cause issues
537
+ with contextlib.suppress(AttributeError, RuntimeError, ImportError):
538
+ vtk.vtkObject.GlobalWarningDisplayOff()
539
+ # Try to set error callback to prevent crashes
540
+ with contextlib.suppress(AttributeError, RuntimeError):
541
+ vtk.vtkOutputWindow.SetGlobalWarningDisplay(0)
542
+
543
+ # Set matplotlib to non-interactive backend for headless environments
544
+ with contextlib.suppress(ImportError, ValueError, RuntimeError):
545
+ # Use Agg backend (no display needed) for headless environments
546
+ plt.switch_backend("Agg")
547
+
548
+ # Create a new visualizer instance in this worker process
549
+ # Wrap in try-except to catch initialization errors that might cause segfaults
550
+ try:
551
+ visualizer = TractographyVisualizer(**visualizer_params)
552
+ except (OSError, RuntimeError, MemoryError) as e:
553
+ logger.exception(
554
+ "Failed to initialize visualizer in worker for %s/%s: %s",
555
+ subject_id,
556
+ tract_name,
557
+ type(e).__name__,
558
+ )
559
+ return (subject_id, tract_name, {})
560
+
561
+ # Process the tract with additional error handling
562
+ try:
563
+ results = visualizer._process_single_tract(
564
+ subject_id=subject_id,
565
+ tract_name=tract_name,
566
+ tract_file=tract_file,
567
+ subject_ref_img=Path(subject_ref_img),
568
+ tract_output_dir=Path(tract_output_dir),
569
+ subjects_mni_space=subjects_mni_space,
570
+ atlas_files=atlas_files,
571
+ metric_files=metric_files,
572
+ atlas_ref_img=atlas_ref_img,
573
+ flip_lr=flip_lr,
574
+ skip_checks=skip_checks,
575
+ **kwargs,
576
+ )
577
+ except MemoryError:
578
+ # Memory errors are critical - log and return empty results
579
+ logger.exception(
580
+ "Memory error in worker process for %s/%s. Consider reducing n_jobs or increasing memory allocation.",
581
+ subject_id,
582
+ tract_name,
583
+ )
584
+ results = {}
585
+ except (OSError, ValueError, RuntimeError) as e:
586
+ # Catch specific exceptions that might cause process crashes
587
+ # Log the error but don't re-raise to prevent process pool from breaking
588
+ logger.exception(
589
+ "Error in worker process for %s/%s (%s)",
590
+ subject_id,
591
+ tract_name,
592
+ type(e).__name__,
593
+ )
594
+ # Return empty results dict on error
595
+ results = {}
596
+ except Exception as e:
597
+ # Catch any other unexpected exceptions
598
+ logger.exception(
599
+ "Unexpected error in worker process for %s/%s (%s)",
600
+ subject_id,
601
+ tract_name,
602
+ type(e).__name__,
603
+ )
604
+ # Return empty results dict on error
605
+ results = {}
606
+ except (OSError, ValueError, RuntimeError, MemoryError) as e:
607
+ # Catch errors during initialization or setup
608
+ logger.exception(
609
+ "Error during worker initialization for %s/%s (%s)",
610
+ subject_id,
611
+ tract_name,
612
+ type(e).__name__,
613
+ )
614
+ results = {}
615
+ except Exception as e:
616
+ # Catch any other unexpected exceptions during setup
617
+ logger.exception(
618
+ "Unexpected error during worker setup for %s/%s (%s)",
619
+ subject_id,
620
+ tract_name,
621
+ type(e).__name__,
622
+ )
623
+ results = {}
624
+ finally:
625
+ # Clean up the visualizer instance and force garbage collection
626
+ # This is critical to prevent memory leaks that could cause OOM kills
627
+ if visualizer is not None:
628
+ # Try to clean up any VTK objects
629
+ with contextlib.suppress(AttributeError, RuntimeError, TypeError):
630
+ del visualizer
631
+ # Clear results dict reference if it's large
632
+ if results:
633
+ # Keep only the minimal results (file paths, not data)
634
+ pass # Results dict should only contain paths, not large data
635
+
636
+ # Force garbage collection multiple times to handle circular references
637
+ # VTK objects can have complex reference cycles
638
+ for _ in range(3):
639
+ gc.collect()
640
+
641
+ return (subject_id, tract_name, results)
642
+
643
+
644
+ class TractographyVisualizationError(Exception):
645
+ """Base exception for tractography visualization errors."""
646
+
647
+
648
+ class InvalidInputError(TractographyVisualizationError):
649
+ """Raised when input data is invalid."""
650
+
651
+
652
+ class TractographyVisualizer:
653
+ """A class for visualizing diffusion tractography data.
654
+
655
+ This class provides methods for loading, processing, and visualizing
656
+ tractography data, including generating static images, animations, and
657
+ quality metrics.
658
+
659
+ Parameters
660
+ ----------
661
+ reference_image : str | Path, optional
662
+ Path to the reference T1-weighted image. Can be set later via
663
+ `set_reference_image()`.
664
+ output_directory : str | Path, optional
665
+ Default output directory for generated files. Can be set later via
666
+ `set_output_directory()`.
667
+ gif_size : tuple[int, int], optional
668
+ Size of generated GIFs in pixels. Default is (608, 608).
669
+ gif_duration : float, optional
670
+ Duration per frame in seconds. Default is 0.2.
671
+ gif_palette_size : int, optional
672
+ Color palette size for GIF optimization. Default is 64.
673
+ gif_frames : int, optional
674
+ Number of frames in rotation animation. Default is 60.
675
+ min_streamline_length : float, optional
676
+ Minimum streamline length for CCI calculation. Default is 40.0.
677
+ cci_threshold : float, optional
678
+ Minimum CCI value to keep streamlines. Default is 1.0.
679
+ afq_resample_points : int, optional
680
+ Number of points for AFQ resampling. Default is 100.
681
+ n_jobs : int, optional
682
+ Number of parallel jobs to run for processing multiple subjects/tracts.
683
+ Default is 1 (sequential processing). Use -1 to automatically determine
684
+ optimal number based on available resources (respects SLURM allocations
685
+ and OpenMP thread settings to prevent oversubscription).
686
+ Only used in `run_quality_check_workflow()`.
687
+
688
+ Note: When running under SLURM, this will automatically use
689
+ SLURM_CPUS_PER_TASK or SLURM_JOB_CPUS_PER_NODE. If OMP_NUM_THREADS is set,
690
+ it will divide the available CPUs by the number of OpenMP threads to
691
+ prevent resource contention.
692
+ max_memory_mb : float | None, optional
693
+ Maximum memory limit in MB for the process. If set, the process will be
694
+ killed by the OS if it exceeds this limit. This helps prevent OOM kills
695
+ by setting a hard limit. Default is None (no limit).
696
+
697
+ Note: This uses resource.setrlimit() which sets a virtual memory limit.
698
+ Once set, the limit cannot be increased (only decreased).
699
+
700
+ Examples
701
+ --------
702
+ Single subject usage:
703
+ >>> visualizer = TractographyVisualizer(
704
+ ... reference_image="path/to/t1w.nii.gz", output_directory="output/"
705
+ ... )
706
+ >>> visualizer.generate_videos(
707
+ ... tract_files=["tract1.trk", "tract2.trk"], ref_file="t1w.nii.gz"
708
+ ... )
709
+
710
+ Multiple subjects usage (initialize once, set data per subject):
711
+ >>> visualizer = TractographyVisualizer(output_directory="output/")
712
+ >>> # Process subject 1
713
+ >>> visualizer.generate_videos(
714
+ ... tract_files=["subj1_tract1.trk"], ref_file="subj1_t1w.nii.gz"
715
+ ... )
716
+ >>> # Process subject 2
717
+ >>> visualizer.generate_videos(
718
+ ... tract_files=["subj2_tract1.trk"], ref_file="subj2_t1w.nii.gz"
719
+ ... )
720
+ """
721
+
722
+ def __init__(
723
+ self,
724
+ reference_image: str | Path | None = None,
725
+ output_directory: str | Path | None = None,
726
+ *,
727
+ gif_size: tuple[int, int] = (608, 608),
728
+ gif_duration: float = 0.2,
729
+ gif_palette_size: int = 64,
730
+ gif_frames: int = 60,
731
+ min_streamline_length: float = 40.0,
732
+ cci_threshold: float = 1.0,
733
+ afq_resample_points: int = 100,
734
+ n_jobs: int = 1,
735
+ max_memory_mb: float | None = None,
736
+ ) -> None:
737
+ """Initialize the TractographyVisualizer."""
738
+ self._reference_image: Path | None = None
739
+ self._output_directory: Path | None = None
740
+ self.max_memory_mb: float | None = None
741
+
742
+ if reference_image is not None:
743
+ self.set_reference_image(reference_image)
744
+ if output_directory is not None:
745
+ self.set_output_directory(output_directory)
746
+
747
+ self.gif_size = gif_size
748
+ self.gif_duration = gif_duration
749
+ self.gif_palette_size = gif_palette_size
750
+ self.gif_frames = gif_frames
751
+ self.min_streamline_length = min_streamline_length
752
+ self.cci_threshold = cci_threshold
753
+ self.afq_resample_points = afq_resample_points
754
+ # Handle n_jobs: -1 means use all CPUs, otherwise use specified value
755
+ if n_jobs == -1:
756
+ self.n_jobs = _get_optimal_n_jobs()
757
+ else:
758
+ self.n_jobs = n_jobs
759
+
760
+ # Set memory limit if specified
761
+ if max_memory_mb is not None:
762
+ _set_memory_limit(max_memory_mb)
763
+ self.max_memory_mb = max_memory_mb
764
+ else:
765
+ self.max_memory_mb = None
766
+
767
+ def set_reference_image(self, reference_image: str | Path) -> None:
768
+ """Set the reference T1-weighted image.
769
+
770
+ Parameters
771
+ ----------
772
+ reference_image : str | Path
773
+ Path to the reference image file.
774
+
775
+ Raises
776
+ ------
777
+ FileNotFoundError
778
+ If the reference image file does not exist.
779
+ """
780
+ ref_path = Path(reference_image)
781
+ if not ref_path.exists():
782
+ raise FileNotFoundError(f"Reference image not found: {reference_image}")
783
+ self._reference_image = ref_path
784
+
785
+ def set_output_directory(self, output_directory: str | Path) -> None:
786
+ """Set the output directory for generated files.
787
+
788
+ Parameters
789
+ ----------
790
+ output_directory : str | Path
791
+ Path to the output directory. Will be created if it doesn't exist.
792
+
793
+ Raises
794
+ ------
795
+ OSError
796
+ If the directory cannot be created.
797
+ """
798
+ out_dir = Path(output_directory)
799
+ out_dir.mkdir(parents=True, exist_ok=True)
800
+ self._output_directory = out_dir
801
+
802
+ @property
803
+ def reference_image(self) -> Path | None:
804
+ """Get the current reference image path."""
805
+ return self._reference_image
806
+
807
+ @property
808
+ def output_directory(self) -> Path | None:
809
+ """Get the current output directory path."""
810
+ return self._output_directory
811
+
812
+ def get_glass_brain(self, t1w_img: str | Path | None = None) -> actor.Actor:
813
+ """Get the glass brain actor for the T1-weighted image.
814
+
815
+ Parameters
816
+ ----------
817
+ t1w_img : str | Path | None, optional
818
+ Path to the T1-weighted image. If None, uses the reference image
819
+ set during initialization or via `set_reference_image()`.
820
+
821
+ Returns
822
+ -------
823
+ actor.Actor
824
+ The glass brain actor.
825
+
826
+ Raises
827
+ ------
828
+ FileNotFoundError
829
+ If the image file is not found.
830
+ InvalidInputError
831
+ If no reference image is provided and none was set.
832
+ """
833
+ if t1w_img is None:
834
+ if self._reference_image is None:
835
+ raise InvalidInputError(
836
+ "No reference image provided. Set it via constructor or "
837
+ "set_reference_image() method, or pass it as an argument.",
838
+ )
839
+ t1w_img = self._reference_image
840
+
841
+ try:
842
+ mask_img = nib.load(str(t1w_img))
843
+ mask_data = mask_img.get_fdata() # type: ignore[attr-defined]
844
+ except (OSError, ValueError, RuntimeError) as e:
845
+ raise TractographyVisualizationError(
846
+ f"Failed to load glass brain from {t1w_img}: {e}",
847
+ ) from e
848
+ else:
849
+ return actor.contour_from_roi(mask_data, color=[0, 0, 0], opacity=0.05)
850
+
851
def _set_anatomical_camera(
    self,
    scene: window.Scene,
    centroid: np.ndarray,
    view_name: str,
    *,
    camera_distance: float | None = None,
    bbox_size: np.ndarray | None = None,
) -> None:
    """Position the scene camera for one of the standard anatomical views.

    Method-level wrapper around ``utils.set_anatomical_camera`` that converts
    the utility's ``ValueError`` into this package's ``InvalidInputError``.

    Parameters
    ----------
    scene : window.Scene
        FURY scene whose camera is configured.
    centroid : np.ndarray
        3D point the camera is aimed at (the streamline centroid).
    view_name : str
        Anatomical view name: "coronal", "axial", or "sagittal".
    camera_distance : float | None, optional
        Camera-to-centroid distance. When None it is derived from
        ``bbox_size``.
    bbox_size : np.ndarray | None, optional
        Streamline bounding-box size, used when ``camera_distance`` is None.

    Raises
    ------
    InvalidInputError
        If the utility rejects the arguments (e.g. an unknown ``view_name``)
        with a ``ValueError``.
    """
    camera_options = {"camera_distance": camera_distance, "bbox_size": bbox_size}
    try:
        set_anatomical_camera(scene, centroid, view_name, **camera_options)
    except ValueError as err:
        message = str(err)
        raise InvalidInputError(message) from err
+
893
def _create_scene(
    self,
    *,
    ref_img: str | Path | None = None,
    show_glass_brain: bool = True,
) -> tuple[window.Scene, actor.Actor | None]:
    """Build a fresh white-background scene, optionally with a glass brain.

    Parameters
    ----------
    ref_img : str | Path | None, optional
        Image forwarded to ``get_glass_brain``. When None, ``get_glass_brain``
        falls back to the visualizer's configured reference image.
    show_glass_brain : bool, optional
        Whether to add the glass-brain outline actor. Default is True.

    Returns
    -------
    tuple[window.Scene, actor.Actor | None]
        The new scene and the glass-brain actor that was added to it, or
        None in the second slot when ``show_glass_brain`` is False.
    """
    new_scene = window.Scene()
    new_scene.SetBackground(1, 1, 1)  # white background

    if not show_glass_brain:
        return new_scene, None

    glass_brain = self.get_glass_brain(ref_img)
    new_scene.add(glass_brain)
    return new_scene, glass_brain
+
923
def _create_streamline_actor(
    self,
    streamlines: Streamlines,
    colors: np.ndarray | None = None,
) -> actor.Actor:
    """Create a line actor for ``streamlines``, reconciling color-count mismatches.

    Parameters
    ----------
    streamlines : Streamlines
        The streamlines to visualize.
    colors : np.ndarray | None, optional
        Color entries indexed along axis 0, one per streamline. If there are
        fewer entries than streamlines they are cycled (entry ``i`` becomes
        ``colors[i % len(colors)]``); if there are more, the extras are
        dropped. None, or an empty array when there are streamlines to draw,
        means the actor is created without explicit colors.

    Returns
    -------
    actor.Actor
        The streamline actor.
    """
    n_streamlines = len(streamlines)

    # An empty color array cannot be cycled (the old tiling code divided by
    # len(colors) and raised ZeroDivisionError), so treat it like None.
    if colors is None or (len(colors) == 0 and n_streamlines > 0):
        return actor.line(streamlines)

    if len(colors) != n_streamlines:
        # Cycle or truncate along axis 0. Unlike np.tile(colors, (reps, 1)),
        # this keeps the original rank, so 1D per-streamline values stay 1D
        # and (k, 3)/(k, 4) RGB(A) arrays stay 2D.
        colors = np.take(colors, np.arange(n_streamlines) % len(colors), axis=0)

    return actor.line(streamlines, colors=colors)
+
953
+ def _extract_tract_name(self, tract_file: str | Path) -> str:
954
+ """Extract tract name from file path (without extension).
955
+
956
+ Parameters
957
+ ----------
958
+ tract_file : str | Path
959
+ Path to tractography file.
960
+
961
+ Returns
962
+ -------
963
+ str
964
+ Tract name without extension.
965
+ """
966
+ return Path(tract_file).stem
967
+
968
def load_tract(
    self,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
) -> actor.Actor:
    """Load a tractography file and wrap its streamlines in a line actor.

    The tract is converted to RAS+ millimetre space and then mapped into the
    voxel grid of ``ref_img`` with the inverse of that image's affine.
    NOTE(review): presumably this is so it aligns with ``get_glass_brain``,
    which draws the raw voxel array without an affine; confirm.

    Parameters
    ----------
    tract_file : str | Path
        Path to the tractography file (.trk).
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization or via `set_reference_image()`.

    Returns
    -------
    actor.Actor
        The tractography actor.

    Raises
    ------
    InvalidInputError
        If no reference image is provided and none was set.
    TractographyVisualizationError
        If loading either file or transforming the streamlines fails. A
        missing file is reported here as well: ``FileNotFoundError`` is an
        ``OSError`` and is wrapped rather than propagated.
    """
    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image

    try:
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        # RAS+ mm -> reference voxel indices.
        tract_to_ref = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to load tract from {tract_file}: {e}",
        ) from e
    else:
        return actor.line(tract_to_ref)
+
1020
def weighted_afq(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    metric_file: str | Path,
) -> np.ndarray:
    """Compute a Gaussian-weighted AFQ profile of a scalar metric along a tract.

    The atlas bundle is clustered with QuickBundles using an infinite
    distance threshold, which merges all streamlines into one cluster. The
    metric resamples them to ``self.afq_resample_points`` points. The first
    centroid serves as the reference: subject streamlines are oriented
    against it with ``orient_by_streamline``, weighted with
    ``gaussian_weights`` and sampled from the metric image with
    ``afq_profile``.

    Parameters
    ----------
    tract_file : str | Path
        Subject tractography file.
    atlas_file : str | Path
        Atlas tractography file used to derive the reference centroid.
    metric_file : str | Path
        Scalar metric image (e.g. an FA map) sampled along the tract.

    Returns
    -------
    np.ndarray
        The AFQ profile returned by ``afq_profile``.

    Raises
    ------
    InvalidInputError
        If clustering the atlas yields no centroids.
    TractographyVisualizationError
        If loading or processing fails (I/O errors, including missing files,
        are wrapped here).
    """
    try:
        atlas_sft = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        subject_sft = load_trk(str(tract_file), "same", bbox_valid_check=False)
        metric_nifti = nib.load(str(metric_file))
        metric_data = metric_nifti.get_fdata()  # type: ignore[attr-defined]

        resample = ResampleFeature(nb_points=self.afq_resample_points)
        clusterer = QuickBundles(
            threshold=np.inf,
            metric=AveragePointwiseEuclideanMetric(resample),
        )
        atlas_centroids = clusterer.cluster(atlas_sft.streamlines).centroids

        if len(atlas_centroids) == 0:
            raise InvalidInputError(
                "No centroids found in atlas tractography clustering.",
            )

        oriented = orient_by_streamline(subject_sft.streamlines, atlas_centroids[0])
        streamline_weights = gaussian_weights(oriented)
        afq = afq_profile(
            metric_data,
            oriented,
            affine=metric_nifti.affine,  # type: ignore[attr-defined]
            weights=streamline_weights,
        )
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate weighted AFQ profile: {e}",
        ) from e
    else:
        return afq
+
1086
def calc_cci(
    self,
    tract: StatefulTractogram,
    ref_img: str | Path | None = None,
) -> tuple[np.ndarray, StatefulTractogram]:
    """Compute the Cluster Confidence Index (CCI) and keep confident streamlines.

    Streamlines strictly longer than ``self.min_streamline_length`` are
    scored with ``cluster_confidence``. Those with a score of at least
    ``self.cci_threshold`` are kept.

    Parameters
    ----------
    tract : StatefulTractogram
        The tractography data. NOTE(review): the streamlines are used as-is
        and the result is labelled ``Space.RASMM``, so this assumes ``tract``
        is already in RAS+ mm space. TODO confirm with callers.
    ref_img : str | Path | None, optional
        Reference image for the returned tractogram. If None, uses the
        reference image set during initialization or via
        `set_reference_image()`.

    Returns
    -------
    tuple[np.ndarray, StatefulTractogram]
        The CCI values of the kept streamlines, and a tractogram containing
        exactly those streamlines, in the same order.

    Raises
    ------
    InvalidInputError
        If the tractogram is empty, no streamline passes the length filter,
        or no reference image is available.
    TractographyVisualizationError
        If the calculation fails.
    """
    if not tract.streamlines or not len(tract.streamlines):
        raise InvalidInputError("Tractogram is empty.")

    if ref_img is None and self._reference_image is None:
        raise InvalidInputError(
            "No reference image provided. Set it via constructor or "
            "set_reference_image() method, or pass it as an argument.",
        )
    if ref_img is None:
        ref_img = self._reference_image

    try:
        streamline_lengths = list(length(tract.streamlines))
        candidates = Streamlines()
        for streamline, sl_length in zip(tract.streamlines, streamline_lengths):
            if sl_length > self.min_streamline_length:
                candidates.append(streamline)

        if not len(candidates):
            raise InvalidInputError(
                f"No streamlines longer than {self.min_streamline_length}mm found.",
            )

        cci_scores = cluster_confidence(candidates)
        above_threshold = cci_scores >= self.cci_threshold

        # Keep streamlines and their scores aligned via the same boolean mask.
        kept = Streamlines()
        for idx, streamline in enumerate(candidates):
            if above_threshold[idx]:
                kept.append(streamline)
        kept_cci = cci_scores[above_threshold]

        kept_tract = StatefulTractogram(kept, nib.load(str(ref_img)), Space.RASMM)
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate CCI: {e}",
        ) from e
    else:
        return kept_cci, kept_tract
+
1163
def generate_anatomical_views(
    self,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
) -> dict[str, Path]:
    """Generate snapshots from standard anatomical views.

    Creates static images from coronal, axial, and sagittal views of the tract.
    Each image is written to ``<output_dir>/<tract_name>_<view>.png``. Views
    whose file already exists are not regenerated. If every requested file
    exists, the tract is never loaded.

    Parameters
    ----------
    tract_file : str | Path
        Path to the tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization or via `set_reference_image()`.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths. Views that
        were skipped (insufficient memory, or a VTK allocation failure while
        building the actor or rendering) are absent from the dictionary.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, or a view
        name is invalid.
    TractographyVisualizationError
        If loading, transforming, or rendering fails. Missing files are
        reported here too, since ``FileNotFoundError`` is an ``OSError`` and
        is wrapped.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> # Generate all views
    >>> views = visualizer.generate_anatomical_views("tract.trk")
    >>> # Generate specific views
    >>> views = visualizer.generate_anatomical_views(
    ...     "tract.trk", views=["coronal", "axial"]
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    # Check for existing outputs BEFORE loading the tract so a fully cached
    # call never pays the memory cost of loading streamlines.
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"{tract_name}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    def _is_memory_error(exc: RuntimeError) -> bool:
        # VTK failures such as std::bad_alloc arrive as RuntimeError; classify
        # them by message.
        msg = str(exc).lower()
        return "bad_alloc" in msg or "memory" in msg or "allocation" in msg

    try:
        # Load tract only if we need to generate at least one view
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        # RAS+ mm -> reference voxel indices (the glass brain's frame).
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )

        streamline_colors = calculate_direction_colors(tract_streamlines)
        centroid = calculate_centroid(tract_streamlines)
        # The bounding box depends only on the streamlines, so compute it once
        # instead of once per view.
        bbox_size = calculate_bbox_size(tract_streamlines)

        # Log tract statistics for debugging
        n_streamlines = len(tract_streamlines)
        total_points = sum(len(sl) for sl in tract_streamlines)
        avg_points = total_points / n_streamlines if n_streamlines else 0
        max_points = max(len(sl) for sl in tract_streamlines) if n_streamlines else 0
        logger.info(
            "Tract statistics: %d streamlines, %d total points, "
            "%.1f avg points/streamline, %d max points/streamline",
            n_streamlines,
            total_points,
            avg_points,
            max_points,
        )

        # Warn if tract is very dense (high risk of std::bad_alloc)
        if total_points > DENSE_TRACT_POINT_THRESHOLD:
            logger.warning(
                "Very dense tract detected (%d total points). "
                "This may cause std::bad_alloc errors due to memory fragmentation. "
                "Consider: (1) resampling streamlines to reduce points, "
                "(2) filtering with higher cci_threshold, (3) reducing image resolution.",
                total_points,
            )
        if max_points > VERY_LONG_STREAMLINE_POINT_THRESHOLD:
            logger.warning(
                "Very long streamline detected (%d points). "
                "VTK may need large contiguous memory allocation which can fail due to fragmentation.",
                max_points,
            )

        # Clean up loaded objects that are no longer needed
        del tract, ref_img_obj

        for view_name in views_to_generate:
            output_image = output_dir / f"{tract_name}_{view_name}.png"

            # Already recorded in generated_views by the pre-check above.
            if output_image.exists():
                continue

            # Check memory before creating VTK actor
            estimated_memory = _estimate_actor_memory_mb(tract_streamlines, figure_size)
            if not _check_memory_available(estimated_memory, safety_margin=0.3):
                logger.warning(
                    "Insufficient memory to create actor for view %s. Skipping.",
                    view_name,
                )
                continue

            scene, _ = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)

            try:
                tract_actor = self._create_streamline_actor(tract_streamlines, streamline_colors)
                scene.add(tract_actor)
            except RuntimeError as e:
                if _is_memory_error(e):
                    logger.exception(
                        "VTK memory allocation failed for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution or filtering streamlines.",
                        view_name,
                        n_streamlines,
                        total_points,
                    )
                    scene.clear()
                    continue
                raise

            # Camera handles the view; streamlines are not rotated.
            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Rendering can also hit std::bad_alloc; skip the view rather than
            # abort every remaining view (same policy as generate_atlas_views).
            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            except RuntimeError as e:
                if _is_memory_error(e):
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        n_streamlines,
                        total_points,
                    )
                    scene.clear()
                    del tract_actor
                    continue
                raise

            scene.clear()
            del tract_actor
            generated_views[view_name] = output_image

        # Clean up remaining large objects
        del tract_streamlines, streamline_colors
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate anatomical views: {e}",
        ) from e
    else:
        return generated_views
+
1401
def generate_atlas_views(
    self,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    ref_img: str | Path | None = None,  # Alias for atlas_ref_img
    flip_lr: bool = False,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    atlas_name: str | None = None,
) -> dict[str, Path]:
    """Generate anatomical views for an atlas tract.

    Creates static images from coronal, axial, and sagittal views of the atlas
    tract. This is useful for comparing subject tracts to atlas tracts using
    the same viewing angles. Images are written to
    ``<output_dir>/<atlas_name>_atlas_<view>.png``. Existing files are not
    regenerated, and if all requested files exist the atlas is never loaded.

    Parameters
    ----------
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to the reference image that matches the atlas coordinate space
        (e.g., MNI template if atlas is in MNI space). This is important if
        the atlas is in a different space than the subject reference image.
        If neither this nor ``ref_img`` is given, the visualizer's reference
        image is used.
    ref_img : str | Path | None, optional
        Backward-compatible alias for ``atlas_ref_img``. Ignored when
        ``atlas_ref_img`` is given.
    flip_lr : bool, optional
        Mirror the atlas left-right by negating the RAS+ world x coordinate
        (x -> -x), before mapping it into the reference voxel grid. For
        MNI-style templates world x = 0 is the mid-sagittal plane. Try this
        if the atlas appears on the wrong side compared to the subject.
        Default is False.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    atlas_name : str | None, optional
        Name prefix for output files. If None, uses the stem of atlas_file.
        Default is None.

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths. Views that
        were skipped (insufficient memory, or a VTK allocation failure while
        building the actor or rendering) are absent from the dictionary.

    Raises
    ------
    InvalidInputError
        If no output directory is available, a view name is invalid, or an
        image must be generated but no reference image is available.
    TractographyVisualizationError
        If loading, transforming, or rendering fails. Missing files are
        reported here too, since ``FileNotFoundError`` is an ``OSError`` and
        is wrapped.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> # Generate all atlas views
    >>> atlas_views = visualizer.generate_atlas_views("atlas_tract.trk")
    >>> # Generate specific views
    >>> atlas_views = visualizer.generate_atlas_views(
    ...     "atlas_tract.trk", views=["coronal", "axial"]
    ... )
    """
    # Resolve the reference image for the atlas coordinate transformation;
    # ref_img is accepted as an alias for backward compatibility.
    if atlas_ref_img is None and ref_img is not None:
        atlas_ref_img = ref_img
    atlas_ref_img = self._reference_image if atlas_ref_img is None else Path(atlas_ref_img)

    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Use provided atlas_name or derive from file
    atlas_name = self._extract_tract_name(atlas_file) if atlas_name is None else str(atlas_name)

    generated_views: dict[str, Path] = {}

    # Check for existing outputs BEFORE loading the atlas so a fully cached
    # call never pays the memory cost of loading streamlines.
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"{atlas_name}_atlas_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    # Without this check a missing reference reached nib.load("None") and
    # surfaced as an obscure file-not-found error.
    if atlas_ref_img is None:
        raise InvalidInputError(
            "No reference image provided. Set it via constructor or "
            "set_reference_image() method, or pass atlas_ref_img as an argument.",
        )

    def _is_memory_error(exc: RuntimeError) -> bool:
        # VTK failures such as std::bad_alloc arrive as RuntimeError; classify
        # them by message.
        msg = str(exc).lower()
        return "bad_alloc" in msg or "memory" in msg or "allocation" in msg

    try:
        # Load atlas tract only if we need to generate at least one view
        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()
        atlas_ref_img_obj = nib.load(str(atlas_ref_img))

        # RAS+ mm -> reference voxel indices (the frame the glass brain is drawn
        # in, since get_glass_brain passes the raw voxel array without an affine).
        rasmm_to_vox = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
        if flip_lr:
            # Mirror in world space (x -> -x) BEFORE going to voxels. Negating
            # the voxel x index instead (the previous behavior) pushed the tract
            # to negative indices, off the image grid and away from the glass
            # brain. Composing the flip avoids an extra copy of the streamlines.
            rasmm_to_vox = rasmm_to_vox @ np.diag([-1.0, 1.0, 1.0, 1.0])
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, rasmm_to_vox)

        # Release the loaded tractogram before building VTK actors.
        del atlas_tract, atlas_ref_img_obj

        streamline_colors = calculate_direction_colors(atlas_streamlines)
        centroid = calculate_centroid(atlas_streamlines)
        # Loop-invariant: depends only on the streamlines.
        bbox_size = calculate_bbox_size(atlas_streamlines)

        for view_name in views_to_generate:
            output_image = output_dir / f"{atlas_name}_atlas_{view_name}.png"

            # Already recorded in generated_views by the pre-check above.
            if output_image.exists():
                continue

            # Same memory guard as generate_anatomical_views.
            estimated_memory = _estimate_actor_memory_mb(atlas_streamlines, figure_size)
            if not _check_memory_available(estimated_memory, safety_margin=0.3):
                logger.warning(
                    "Insufficient memory to create actor for view %s. Skipping.",
                    view_name,
                )
                continue

            scene, _ = self._create_scene(ref_img=atlas_ref_img, show_glass_brain=show_glass_brain)

            try:
                tract_actor = self._create_streamline_actor(atlas_streamlines, streamline_colors)
                scene.add(tract_actor)
            except RuntimeError as e:
                if _is_memory_error(e):
                    logger.exception(
                        "VTK memory allocation failed for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution or filtering streamlines.",
                        view_name,
                        len(atlas_streamlines),
                        sum(len(sl) for sl in atlas_streamlines),
                    )
                    scene.clear()
                    continue
                raise

            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Record the scene (this can also fail with std::bad_alloc)
            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            except RuntimeError as e:
                if _is_memory_error(e):
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(atlas_streamlines),
                        sum(len(sl) for sl in atlas_streamlines),
                    )
                    scene.clear()
                    del tract_actor
                    continue
                raise

            generated_views[view_name] = output_image
            scene.clear()
            del tract_actor

        # Clean up large objects
        del atlas_streamlines, streamline_colors
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate atlas views: {e}",
        ) from e
    else:
        return generated_views
+
1630
+ def plot_cci(
1631
+ self,
1632
+ cci: np.ndarray,
1633
+ keep_tract: StatefulTractogram,
1634
+ hist_file: str | Path,
1635
+ ref_img: str | Path | None = None,
1636
+ *,
1637
+ views: list[str] | None = None,
1638
+ output_dir: str | Path | None = None,
1639
+ figure_size: tuple[int, int] = (800, 800),
1640
+ show_glass_brain: bool = True,
1641
+ bins: int = 100,
1642
+ ) -> dict[str, Path]:
1643
+ """Plot CCI with anatomical views.
1644
+
1645
+ Generates anatomical views (coronal, axial, sagittal) with streamlines
1646
+ colored by CCI values, along with a histogram plot.
1647
+
1648
+ Parameters
1649
+ ----------
1650
+ cci : np.ndarray
1651
+ CCI values array (one per streamline in keep_tract).
1652
+ keep_tract : StatefulTractogram
1653
+ Filtered tractogram to visualize (should match CCI array length).
1654
+ hist_file : str | Path
1655
+ Path to save the histogram plot.
1656
+ ref_img : str | Path | None, optional
1657
+ Path to the reference image. If None, uses the reference image
1658
+ set during initialization.
1659
+ views : list[str] | None, optional
1660
+ List of views to generate. Options: "coronal", "axial", "sagittal".
1661
+ If None, generates all three views. Default is None.
1662
+ output_dir : str | Path | None, optional
1663
+ Output directory for generated images. If None, uses the output
1664
+ directory set during initialization.
1665
+ figure_size : tuple[int, int], optional
1666
+ Size of the output images in pixels. Default is (800, 800).
1667
+ show_glass_brain : bool, optional
1668
+ Whether to show the glass brain outline. Default is True.
1669
+ bins : int, optional
1670
+ Number of bins for histogram. Default is 100.
1671
+
1672
+ Returns
1673
+ -------
1674
+ dict[str, Path]
1675
+ Dictionary mapping view names to their output file paths.
1676
+ Keys: "coronal", "axial", "sagittal", "histogram".
1677
+
1678
+ Raises
1679
+ ------
1680
+ InvalidInputError
1681
+ If CCI array is empty or invalid, or length mismatch with tract.
1682
+ TractographyVisualizationError
1683
+ If plotting fails.
1684
+ """
1685
+ if len(cci) == 0:
1686
+ raise InvalidInputError("CCI array is empty.")
1687
+
1688
+ # Get output directory
1689
+ if output_dir is None:
1690
+ if self._output_directory is None:
1691
+ raise InvalidInputError(
1692
+ "No output directory provided. Set it via constructor or "
1693
+ "set_output_directory() method, or pass it as an argument.",
1694
+ )
1695
+ output_dir = self._output_directory
1696
+ else:
1697
+ output_dir = Path(output_dir)
1698
+ output_dir.mkdir(parents=True, exist_ok=True)
1699
+
1700
+ if ref_img is None:
1701
+ if self._reference_image is None:
1702
+ raise InvalidInputError(
1703
+ "No reference image provided. Set it via constructor or "
1704
+ "set_reference_image() method, or pass it as an argument.",
1705
+ )
1706
+ ref_img = self._reference_image
1707
+ else:
1708
+ ref_img = Path(ref_img)
1709
+
1710
+ # Use standard anatomical view angles from utils
1711
+ view_angles = ANATOMICAL_VIEW_ANGLES
1712
+
1713
+ # Determine which views to generate
1714
+ if views is None:
1715
+ views_to_generate = list(view_angles.keys())
1716
+ else:
1717
+ views_to_generate = views
1718
+ invalid_views = [v for v in views_to_generate if v not in view_angles]
1719
+ if invalid_views:
1720
+ raise InvalidInputError(
1721
+ f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
1722
+ )
1723
+
1724
+ generated_views: dict[str, Path] = {}
1725
+
1726
+ try:
1727
+ # Create histogram
1728
+ hist_path = Path(hist_file)
1729
+ hist_path.parent.mkdir(parents=True, exist_ok=True)
1730
+
1731
+ # Skip histogram if file already exists
1732
+ if not hist_path.exists():
1733
+ fig, ax = plt.subplots(1, figsize=(8, 6))
1734
+ ax.hist(cci, bins=bins, histtype="step")
1735
+ ax.set_xlabel("CCI")
1736
+ ax.set_ylabel("# streamlines")
1737
+ ax.set_title("CCI Distribution")
1738
+ ax.grid(visible=True, alpha=0.3)
1739
+ fig.savefig(str(hist_path), dpi=150, bbox_inches="tight")
1740
+ plt.close(fig)
1741
+ else:
1742
+ logger.debug("Skipping generation of %s (file already exists)", hist_path)
1743
+
1744
+ generated_views["histogram"] = hist_path
1745
+
1746
+ # Check if all view files already exist BEFORE processing tract
1747
+ # This prevents unnecessary memory usage when files are already generated
1748
+ all_view_files_exist = True
1749
+ for view_name in views_to_generate:
1750
+ output_image = output_dir / f"cci_{view_name}.png"
1751
+ if output_image.exists():
1752
+ generated_views[view_name] = output_image
1753
+ logger.debug("Skipping generation of %s (file already exists)", output_image)
1754
+ else:
1755
+ all_view_files_exist = False
1756
+
1757
+ # If all view files exist, return early without processing tract streamlines
1758
+ if all_view_files_exist:
1759
+ return generated_views
1760
+
1761
+ # Transform tract streamlines to reference space (only if needed)
1762
+ ref_img_obj = nib.load(str(ref_img))
1763
+ tract_streamlines = transform_streamlines(
1764
+ keep_tract.streamlines,
1765
+ np.linalg.inv(ref_img_obj.affine), # type: ignore[attr-defined]
1766
+ )
1767
+
1768
+ # Validate CCI array matches tract (critical for memory safety)
1769
+ if len(cci) != len(tract_streamlines):
1770
+ raise InvalidInputError(
1771
+ f"CCI array length ({len(cci)}) does not match "
1772
+ f"tractogram streamlines ({len(tract_streamlines)}). "
1773
+ f"This can cause memory issues.",
1774
+ )
1775
+
1776
+ # Calculate CCI colors for each streamline
1777
+ # Use the same color scheme as the original CCI visualization
1778
+ cci_min = float(cci.min())
1779
+ cci_max = float(cci.max())
1780
+
1781
+ # Create lookup table with same parameters as original
1782
+ hue = [0.5, 1]
1783
+ saturation = [0.0, 1.0]
1784
+ lut_cmap = actor.colormap_lookup_table(
1785
+ scale_range=(cci_min, cci_max / 4),
1786
+ hue_range=hue,
1787
+ saturation_range=saturation,
1788
+ )
1789
+
1790
+ bar = actor.scalar_bar(lookup_table=lut_cmap)
1791
+
1792
+ # Get centroid using utility function
1793
+ centroid = calculate_centroid(tract_streamlines)
1794
+ gc.collect()
1795
+
1796
+ # Generate each requested view
1797
+ for view_name in views_to_generate:
1798
+ output_image = output_dir / f"cci_{view_name}.png"
1799
+
1800
+ # Skip if file already exists (already added to generated_views above)
1801
+ if output_image.exists():
1802
+ continue
1803
+
1804
+ # Create scene
1805
+ scene = window.Scene()
1806
+ scene.SetBackground(0.5, 0.5, 0.5)
1807
+ scene.add(bar)
1808
+
1809
+ # Add glass brain if requested (no rotation needed - camera handles view)
1810
+ brain_actor = None
1811
+ if show_glass_brain:
1812
+ brain_actor = self.get_glass_brain(ref_img)
1813
+ scene.add(brain_actor)
1814
+
1815
+ # Check memory before creating large VTK actor
1816
+ estimated_memory = _estimate_actor_memory_mb(tract_streamlines, figure_size)
1817
+ if not _check_memory_available(estimated_memory, safety_margin=0.3):
1818
+ logger.warning(
1819
+ "Insufficient memory to create actor for %d streamlines. Skipping view %s",
1820
+ len(tract_streamlines),
1821
+ view_name,
1822
+ )
1823
+ continue
1824
+
1825
+ # Use CCI array directly with original streamlines (no rotation - camera handles view)
1826
+ try:
1827
+ tract_actor = actor.line(
1828
+ tract_streamlines,
1829
+ colors=cci,
1830
+ linewidth=0.1,
1831
+ lookup_colormap=lut_cmap,
1832
+ )
1833
+ scene.add(tract_actor)
1834
+ except RuntimeError as e:
1835
+ # Catch std::bad_alloc and other VTK errors
1836
+ error_msg = str(e).lower()
1837
+ if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
1838
+ logger.exception(
1839
+ "VTK memory allocation failed for view %s (likely std::bad_alloc). "
1840
+ "Try reducing tract size, image resolution, or using n_jobs=1.",
1841
+ view_name,
1842
+ )
1843
+ # Clean up and skip this view
1844
+ scene.clear()
1845
+ continue
1846
+ raise
1847
+
1848
+ # Set camera position for anatomical view using utility function
1849
+ bbox_size = calculate_bbox_size(tract_streamlines)
1850
+
1851
+ # Use helper method to set camera
1852
+ self._set_anatomical_camera(
1853
+ scene,
1854
+ centroid,
1855
+ view_name,
1856
+ bbox_size=bbox_size,
1857
+ )
1858
+
1859
+ # Record the scene (this can also fail with std::bad_alloc)
1860
+ try:
1861
+ window.record(
1862
+ scene=scene,
1863
+ out_path=str(output_image),
1864
+ size=figure_size,
1865
+ )
1866
+ generated_views[view_name] = output_image
1867
+ except RuntimeError as e:
1868
+ # Catch std::bad_alloc and other VTK errors during rendering
1869
+ error_msg = str(e).lower()
1870
+ if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
1871
+ logger.exception(
1872
+ "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
1873
+ "Tract has %d streamlines with %d total points. "
1874
+ "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
1875
+ view_name,
1876
+ len(tract_streamlines),
1877
+ sum(len(sl) for sl in tract_streamlines),
1878
+ )
1879
+ # Clean up and skip this view
1880
+ scene.clear()
1881
+ del tract_actor
1882
+ if brain_actor is not None:
1883
+ del brain_actor
1884
+ del scene
1885
+ gc.collect()
1886
+ continue
1887
+ raise
1888
+
1889
+ # Explicitly clean up scene and actors to free memory
1890
+ # VTK/FURY objects can hold circular references, so explicit cleanup is critical
1891
+ scene.clear()
1892
+ del tract_actor
1893
+ if brain_actor is not None:
1894
+ del brain_actor
1895
+ del scene
1896
+
1897
+ # Force garbage collection between views to free memory
1898
+ # This is critical for large tractograms to prevent OOM kills
1899
+ gc.collect()
1900
+
1901
+ generated_views[view_name] = output_image
1902
+
1903
+ # Clean up large objects after all views are generated
1904
+ del keep_tract, tract_streamlines, cci
1905
+ gc.collect()
1906
+
1907
+ except TractographyVisualizationError:
1908
+ raise
1909
+ except (OSError, ValueError, RuntimeError) as e:
1910
+ raise TractographyVisualizationError(f"Failed to plot CCI: {e}") from e
1911
+ else:
1912
+ return generated_views
1913
+
1914
def plot_afq(
    self,
    metric_file: str | Path,
    metric_str: str,
    tract_file: str | Path,
    atlas_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    colormap: str = "Spectral",
) -> dict[str, Path]:
    """Plot AFQ profile with anatomical views.

    Generates anatomical views (coronal, axial, sagittal) with streamlines
    colored by AFQ profile values, plus a line plot of the profile itself.
    Output files that already exist are reused instead of regenerated, and
    the AFQ profile is only computed when at least one output is missing.

    Parameters
    ----------
    metric_file : str | Path
        Path to the metric image file.
    metric_str : str
        Name of the metric (e.g., "FA", "MD") for labeling.
    tract_file : str | Path
        Path to the tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    colormap : str, optional
        Name of the colormap to use for AFQ profile values.
        Default is "Spectral".

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Also includes the "profile_plot" key for the profile line plot,
        whether it was generated by this call or already existed. Views
        whose rendering failed with a VTK memory-allocation error are
        omitted.

    Raises
    ------
    FileNotFoundError
        If required files are not found.
    InvalidInputError
        If no output directory / reference image is available or a view
        name is invalid.
    TractographyVisualizationError
        If image generation fails.
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Resolve the output directory (explicit argument wins over instance default)
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the reference image (explicit argument wins over instance default)
    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    # Reuse an existing profile plot
    profile_plot_path = output_dir / f"{tract_name}_{metric_str}_profile.png"
    if profile_plot_path.exists():
        logger.debug("Skipping generation of %s (file already exists)", profile_plot_path)
        generated_views["profile_plot"] = profile_plot_path

    # Reuse existing view images BEFORE doing any expensive work
    all_view_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"{tract_name}_{metric_str}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_view_files_exist = False

    # Nothing to do: skip both the AFQ computation and the tract loading
    if all_view_files_exist and profile_plot_path.exists():
        return generated_views

    # Calculate AFQ profile only now that at least one output must be generated.
    # Kept outside the try block so its exceptions propagate unchanged.
    profile = self.weighted_afq(tract_file, atlas_file, metric_file)

    try:
        # Load tract and move it into the reference image's voxel space
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )
        del tract

        # Per-point colors: resample the profile along each streamline's
        # normalized arc index, then map the values through the colormap.
        profile_positions = np.linspace(0, 1, len(profile))
        streamline_point_colors = [
            create_colormap(
                np.interp(np.linspace(0, 1, len(sl)), profile_positions, profile),
                name=colormap,
            )
            for sl in tract_streamlines
        ]
        # A single (K, 3) array (K = total points) lets one actor carry
        # per-vertex colors, instead of creating one VTK actor per streamline.
        point_colors = np.vstack(streamline_point_colors) if streamline_point_colors else None
        del streamline_point_colors

        # Camera parameters are identical for every view; compute once
        centroid = calculate_centroid(tract_streamlines)
        bbox_size = calculate_bbox_size(tract_streamlines)

        # Generate each requested view
        for view_name in views_to_generate:
            output_image = output_dir / f"{tract_name}_{metric_str}_{view_name}.png"

            # Skip if file already exists (already added to generated_views above)
            if output_image.exists():
                continue

            scene, _ = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)

            tract_actor = None
            if point_colors is not None:
                tract_actor = actor.line(tract_streamlines, colors=point_colors)
                scene.add(tract_actor)

            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Record the scene (this can also fail with std::bad_alloc)
            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
                generated_views[view_name] = output_image
            except RuntimeError as e:
                # Catch std::bad_alloc and other VTK errors during rendering
                error_msg = str(e).lower()
                if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(tract_streamlines),
                        sum(len(sl) for sl in tract_streamlines),
                    )
                    # Clean up and skip this view
                    scene.clear()
                    del scene, tract_actor
                    gc.collect()
                    continue
                raise

            # VTK/FURY objects can hold circular references; free them between views
            scene.clear()
            del scene, tract_actor
            gc.collect()

        # Also create the profile line plot (if it doesn't already exist)
        if not profile_plot_path.exists():
            fig, ax = plt.subplots(1, figsize=(8, 6))
            try:
                ax.plot(profile)
                ax.set_xlabel("Node along tract")
                ax.set_ylabel(metric_str)
                ax.set_title(f"AFQ Profile: {metric_str}")
                ax.grid(visible=True, alpha=0.3)
                fig.savefig(str(profile_plot_path), dpi=150, bbox_inches="tight")
            finally:
                # Always release the figure, even if saving fails
                plt.close(fig)
        # Report the profile plot whether it was just created or pre-existing
        generated_views["profile_plot"] = profile_plot_path

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to plot AFQ profile: {e}",
        ) from e
    else:
        return generated_views
+
2145
def calculate_shape_similarity(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    flip_lr: bool = False,
    clust_thr: tuple[float, float, float] = (5, 3, 1.5),
    threshold: float = 6,
    rng: np.random.Generator | None = None,
) -> float:
    """Calculate shape similarity score between tract and atlas.

    Uses DIPY's bundle_shape_similarity function with the Bundle Adjacency (BA) metric
    to compute how closely the shapes of two bundles align.

    Parameters
    ----------
    tract_file : str | Path
        Path to the subject tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to a reference image (e.g., MNI template if the atlas is in MNI
        space). If given, BOTH the subject and the atlas streamlines are
        transformed with the inverse of this image's affine so that they are
        compared in one common coordinate frame; ``clust_thr`` and
        ``threshold`` are then interpreted in that image's voxel units.
        If None, both bundles are compared in RASMM coordinates.
    flip_lr : bool, optional
        Whether to flip left-right (X-axis) the atlas streamlines.
        This may be needed for some coordinate conventions or file formats
        where the left-right orientation differs. Default is False.
    clust_thr : tuple[float, float, float], optional
        Clustering thresholds for QuickBundlesX used internally.
        Default is (5, 3, 1.5).
    threshold : float, optional
        Threshold controlling the strictness of the shape similarity assessment.
        A smaller threshold requires the bundles to be more similar to achieve
        a higher score. Default is 6.
    rng : np.random.Generator | None, optional
        Random number generator. If None, creates a new one. Default is None.

    Returns
    -------
    float
        Bundle similarity score (BA value). Higher values indicate greater
        similarity between the bundles.

    Raises
    ------
    FileNotFoundError
        If required files are not found.
    InvalidInputError
        If tracts are empty or invalid.
    TractographyVisualizationError
        If calculation fails.

    Examples
    --------
    >>> visualizer = TractographyVisualizer()
    >>> score = visualizer.calculate_shape_similarity(
    ...     "subject_tract.trk", "atlas_tract.trk"
    ... )
    >>> print(f"Shape similarity score: {score}")
    """
    try:
        # Load both tracts in world (RASMM) coordinates
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        # Check if tracts are empty
        if len(tract.streamlines) == 0:
            raise InvalidInputError("Subject tract is empty.")
        if len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tract is empty.")

        subject_streamlines = tract.streamlines
        atlas_streamlines = atlas_tract.streamlines

        if atlas_ref_img is not None:
            # The BA metric relies on distances between the two bundles, so they
            # must share a single coordinate frame: apply the same inverse affine
            # to both (the same approach visualize_shape_similarity uses).
            atlas_ref_img_obj = nib.load(str(atlas_ref_img))
            inv_affine = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
            subject_streamlines = transform_streamlines(subject_streamlines, inv_affine)
            atlas_streamlines = transform_streamlines(atlas_streamlines, inv_affine)

        # Apply left-right flip (negate X) to the atlas if requested
        if flip_lr:
            atlas_streamlines = Streamlines(
                [np.column_stack([-sl[:, 0], sl[:, 1], sl[:, 2]]) for sl in atlas_streamlines],
            )

        # Create random number generator if not provided
        if rng is None:
            rng = np.random.default_rng()

        # Calculate shape similarity using DIPY's function
        similarity_score = bundle_shape_similarity(
            subject_streamlines,
            atlas_streamlines,
            rng,
            clust_thr=clust_thr,
            threshold=threshold,
        )

        return float(similarity_score)
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate shape similarity: {e}",
        ) from e
+
2270
def visualize_shape_similarity(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    flip_lr: bool = False,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    subject_color: tuple[float, float, float] = (1.0, 0.0, 0.0),  # Red
    atlas_color: tuple[float, float, float] = (0.0, 0.0, 1.0),  # Blue
) -> dict[str, Path]:
    """Visualize shape similarity by overlaying subject and atlas tracts.

    Generates anatomical views (coronal, axial, sagittal) showing both tracts
    overlaid with different colors to visualize their shape similarity.
    Both tracts are transformed with the inverse affine of the same reference
    image so they are drawn in one common coordinate frame.

    Parameters
    ----------
    tract_file : str | Path
        Path to the subject tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to the reference image that matches the atlas coordinate space
        (e.g., MNI template if atlas is in MNI space). It is also used for the
        glass brain. If None, uses the reference image set during
        initialization.
    flip_lr : bool, optional
        Whether to flip left-right (X-axis) when transforming atlas.
        Default is False.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    subject_color : tuple[float, float, float], optional
        RGB color for subject tract (0-1 range). Default is (1.0, 0.0, 0.0) (red).
    atlas_color : tuple[float, float, float], optional
        RGB color for atlas tract (0-1 range). Default is (0.0, 0.0, 1.0) (blue).

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal". Views whose rendering failed
        with a VTK memory-allocation error are omitted.

    Raises
    ------
    FileNotFoundError
        If required files are not found.
    InvalidInputError
        If no output directory or reference image is available, or a view
        name is invalid.
    TractographyVisualizationError
        If visualization fails.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> views = visualizer.visualize_shape_similarity(
    ...     "subject_tract.trk", "atlas_tract.trk"
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tract_name = self._extract_tract_name(tract_file)
    atlas_name = self._extract_tract_name(atlas_file)
    generated_views: dict[str, Path] = {}

    # Check if all output files already exist BEFORE loading tracts
    # This prevents unnecessary memory usage when files are already generated
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"similarity_{tract_name}_vs_{atlas_name}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    # If all files exist, return early without loading tracts
    if all_files_exist:
        return generated_views

    # Resolve the reference image. Without this fallback, the documented
    # default (None) would be passed to nib.load as the literal path "None".
    if atlas_ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Pass atlas_ref_img, or set one via "
                "constructor or set_reference_image() method.",
            )
        atlas_ref_img = self._reference_image
    else:
        atlas_ref_img = Path(atlas_ref_img)

    try:
        # Load both tracts only if we need to generate at least one view
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        # Check if tracts are empty
        if len(tract.streamlines) == 0:
            raise InvalidInputError("Subject tract is empty.")
        if len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tract is empty.")

        # Transform BOTH tracts with the same inverse affine so they share
        # the reference image's voxel frame
        atlas_ref_img_obj = nib.load(str(atlas_ref_img))
        inv_affine = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
        subject_streamlines = transform_streamlines(tract.streamlines, inv_affine)
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, inv_affine)
        del tract, atlas_tract

        # Optionally mirror the atlas left-right (negate X)
        if flip_lr:
            atlas_streamlines = Streamlines(
                [np.column_stack([-sl[:, 0], sl[:, 1], sl[:, 2]]) for sl in atlas_streamlines],
            )

        # Camera parameters and colors are the same for every view; compute once
        centroid = calculate_combined_centroid(subject_streamlines, atlas_streamlines)
        bbox_size = calculate_combined_bbox_size(subject_streamlines, atlas_streamlines)
        # One RGB row per streamline -> every streamline of a tract gets the same color
        subject_colors = np.tile(subject_color, (len(subject_streamlines), 1))
        atlas_colors = np.tile(atlas_color, (len(atlas_streamlines), 1))

        # Generate each requested view
        for view_name in views_to_generate:
            output_image = output_dir / f"similarity_{tract_name}_vs_{atlas_name}_{view_name}.png"

            # Skip if file already exists (already added to generated_views above)
            if output_image.exists():
                continue

            scene, brain_actor = self._create_scene(ref_img=atlas_ref_img, show_glass_brain=show_glass_brain)

            subject_actor = actor.line(subject_streamlines, colors=subject_colors)
            scene.add(subject_actor)

            atlas_actor = actor.line(atlas_streamlines, colors=atlas_colors)
            scene.add(atlas_actor)

            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Record the scene (this can also fail with std::bad_alloc)
            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
                generated_views[view_name] = output_image
            except RuntimeError as e:
                # Catch std::bad_alloc and other VTK errors during rendering
                error_msg = str(e).lower()
                if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Subject has %d streamlines, atlas has %d streamlines. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(subject_streamlines),
                        len(atlas_streamlines),
                    )
                    # Clean up and skip this view
                    scene.clear()
                    del scene, subject_actor, atlas_actor, brain_actor
                    gc.collect()
                    continue
                raise

            # Free VTK/FURY objects between views
            scene.clear()
            del scene, subject_actor, atlas_actor, brain_actor
            gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to visualize shape similarity: {e}",
        ) from e
    else:
        return generated_views
+
2505
+ def compare_before_after_cci(
2506
+ self,
2507
+ tract_file: str | Path,
2508
+ ref_img: str | Path | None = None,
2509
+ *,
2510
+ views: list[str] | None = None,
2511
+ output_dir: str | Path | None = None,
2512
+ figure_size: tuple[int, int] = (800, 800),
2513
+ show_glass_brain: bool = True,
2514
+ before_color: tuple[float, float, float] = (0.7, 0.7, 0.7), # Light gray
2515
+ after_color: tuple[float, float, float] = (0.0, 0.0, 1.0), # Blue
2516
+ ) -> dict[str, Path]:
2517
+ """Compare tract before and after CCI filtering.
2518
+
2519
+ Generates side-by-side anatomical views (coronal, axial, sagittal) showing
2520
+ the tract before CCI filtering (left) and after CCI filtering (right).
2521
+
2522
+ Parameters
2523
+ ----------
2524
+ tract_file : str | Path
2525
+ Path to the tractography file.
2526
+ ref_img : str | Path | None, optional
2527
+ Path to the reference image. If None, uses the reference image
2528
+ set during initialization.
2529
+ views : list[str] | None, optional
2530
+ List of views to generate. Options: "coronal", "axial", "sagittal".
2531
+ If None, generates all three views. Default is None.
2532
+ output_dir : str | Path | None, optional
2533
+ Output directory for generated images. If None, uses the output
2534
+ directory set during initialization.
2535
+ figure_size : tuple[int, int], optional
2536
+ Size of each side of the comparison image in pixels. The final image
2537
+ will be twice as wide. Default is (800, 800).
2538
+ show_glass_brain : bool, optional
2539
+ Whether to show the glass brain outline. Default is True.
2540
+ before_color : tuple[float, float, float], optional
2541
+ RGB color for before CCI filtering tract (0-1 range).
2542
+ Default is (0.7, 0.7, 0.7) (light gray).
2543
+ after_color : tuple[float, float, float], optional
2544
+ RGB color for after CCI filtering tract (0-1 range).
2545
+ Default is (0.0, 0.0, 1.0) (blue).
2546
+
2547
+ Returns
2548
+ -------
2549
+ dict[str, Path]
2550
+ Dictionary mapping view names to their output file paths.
2551
+ Keys: "coronal", "axial", "sagittal".
2552
+
2553
+ Raises
2554
+ ------
2555
+ FileNotFoundError
2556
+ If required files are not found.
2557
+ InvalidInputError
2558
+ If no output directory is available or invalid view name.
2559
+ TractographyVisualizationError
2560
+ If comparison fails.
2561
+
2562
+ Examples
2563
+ --------
2564
+ >>> visualizer = TractographyVisualizer(
2565
+ ... reference_image="t1w.nii.gz", output_directory="output/"
2566
+ ... )
2567
+ >>> views = visualizer.compare_before_after_cci("tract.trk")
2568
+ """
2569
+ # Use standard anatomical view angles from utils
2570
+ view_angles = ANATOMICAL_VIEW_ANGLES
2571
+
2572
+ # Determine which views to generate
2573
+ if views is None:
2574
+ views_to_generate = list(view_angles.keys())
2575
+ else:
2576
+ views_to_generate = views
2577
+ invalid_views = [v for v in views_to_generate if v not in view_angles]
2578
+ if invalid_views:
2579
+ raise InvalidInputError(
2580
+ f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
2581
+ )
2582
+
2583
+ # Get output directory
2584
+ if output_dir is None:
2585
+ if self._output_directory is None:
2586
+ raise InvalidInputError(
2587
+ "No output directory provided. Set it via constructor or "
2588
+ "set_output_directory() method, or pass it as an argument.",
2589
+ )
2590
+ output_dir = self._output_directory
2591
+ else:
2592
+ output_dir = Path(output_dir)
2593
+ output_dir.mkdir(parents=True, exist_ok=True)
2594
+
2595
+ if ref_img is None:
2596
+ if self._reference_image is None:
2597
+ raise InvalidInputError(
2598
+ "No reference image provided. Set it via constructor or "
2599
+ "set_reference_image() method, or pass it as an argument.",
2600
+ )
2601
+ ref_img = self._reference_image
2602
+ else:
2603
+ ref_img = Path(ref_img)
2604
+
2605
+ tract_name = self._extract_tract_name(tract_file)
2606
+ generated_views: dict[str, Path] = {}
2607
+
2608
+ # Check if all output files already exist BEFORE loading tract
2609
+ # This prevents unnecessary memory usage when files are already generated
2610
+ all_files_exist = True
2611
+ for view_name in views_to_generate:
2612
+ output_image = output_dir / f"cci_before_after_{tract_name}_{view_name}.png"
2613
+ if output_image.exists():
2614
+ generated_views[view_name] = output_image
2615
+ logger.debug("Skipping generation of %s (file already exists)", output_image)
2616
+ else:
2617
+ all_files_exist = False
2618
+
2619
+ # If all files exist, return early without loading tract
2620
+ if all_files_exist:
2621
+ return generated_views
2622
+
2623
+ try:
2624
+ # Load tract only if we need to generate at least one view
2625
+ tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
2626
+ tract.to_rasmm()
2627
+
2628
+ # Calculate CCI and get filtered tract
2629
+ _, filtered_tract = self.calc_cci(tract)
2630
+
2631
+ # Transform both tracts to reference space
2632
+ ref_img_obj = nib.load(str(ref_img))
2633
+
2634
+ # Before CCI filtering: use all streamlines (after length filtering)
2635
+ # We need to get the streamlines that were used for CCI calculation
2636
+ # (i.e., those longer than min_streamline_length)
2637
+ lengths = list(length(tract.streamlines))
2638
+ before_streamlines = Streamlines()
2639
+ for i, sl in enumerate(tract.streamlines):
2640
+ if lengths[i] > self.min_streamline_length:
2641
+ before_streamlines.append(sl)
2642
+
2643
+ before_streamlines = transform_streamlines(
2644
+ before_streamlines,
2645
+ np.linalg.inv(ref_img_obj.affine), # type: ignore[attr-defined]
2646
+ )
2647
+
2648
+ # After CCI filtering: use filtered tract
2649
+ after_streamlines = transform_streamlines(
2650
+ filtered_tract.streamlines,
2651
+ np.linalg.inv(ref_img_obj.affine), # type: ignore[attr-defined]
2652
+ )
2653
+
2654
+ # Calculate combined centroid for rotation (from both tracts)
2655
+ # Calculate combined centroid using utility function
2656
+ centroid = calculate_combined_centroid(before_streamlines, after_streamlines)
2657
+
2658
+ # Generate each requested view
2659
+ for view_name in views_to_generate:
2660
+ output_image = output_dir / f"cci_before_after_{tract_name}_{view_name}.png"
2661
+
2662
+ # Skip if file already exists (already added to generated_views above)
2663
+ if output_image.exists():
2664
+ continue
2665
+
2666
+ # Create side-by-side scenes using helper methods
2667
+ # Left side: Before CCI filtering
2668
+ scene_before, brain_actor_before = self._create_scene(
2669
+ ref_img=ref_img,
2670
+ show_glass_brain=show_glass_brain,
2671
+ )
2672
+
2673
+ # Add before tract (use original streamlines - camera handles view)
2674
+ before_colors = np.tile(before_color, (len(before_streamlines), 1))
2675
+ before_actor = actor.line(before_streamlines, colors=before_colors)
2676
+ scene_before.add(before_actor)
2677
+
2678
+ # Set camera for before scene using utility function
2679
+ bbox_size_before = calculate_bbox_size(before_streamlines)
2680
+ self._set_anatomical_camera(
2681
+ scene_before,
2682
+ centroid,
2683
+ view_name,
2684
+ bbox_size=bbox_size_before,
2685
+ )
2686
+
2687
+ # Right side: After CCI filtering
2688
+ scene_after, brain_actor_after = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)
2689
+
2690
+ # Add after tract (use original streamlines - camera handles view)
2691
+ after_colors = np.tile(after_color, (len(after_streamlines), 1))
2692
+ after_actor = actor.line(after_streamlines, colors=after_colors)
2693
+ scene_after.add(after_actor)
2694
+
2695
+ # Set camera for after scene using utility function
2696
+ bbox_size_after = calculate_bbox_size(after_streamlines)
2697
+ self._set_anatomical_camera(
2698
+ scene_after,
2699
+ centroid,
2700
+ view_name,
2701
+ bbox_size=bbox_size_after,
2702
+ )
2703
+
2704
+ # Record both scenes to temporary files
2705
+
2706
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_before:
2707
+ tmp_before_path = tmp_before.name
2708
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_after:
2709
+ tmp_after_path = tmp_after.name
2710
+
2711
+ # Record both scenes (this can also fail with std::bad_alloc)
2712
+ try:
2713
+ window.record(
2714
+ scene=scene_before,
2715
+ out_path=tmp_before_path,
2716
+ size=figure_size,
2717
+ )
2718
+ window.record(
2719
+ scene=scene_after,
2720
+ out_path=tmp_after_path,
2721
+ size=figure_size,
2722
+ )
2723
+ except RuntimeError as e:
2724
+ # Catch std::bad_alloc and other VTK errors during rendering
2725
+ error_msg = str(e).lower()
2726
+ if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
2727
+ logger.exception(
2728
+ "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
2729
+ "Before has %d streamlines, after has %d streamlines. "
2730
+ "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
2731
+ view_name,
2732
+ len(before_streamlines),
2733
+ len(after_streamlines),
2734
+ )
2735
+ # Clean up and skip this view
2736
+ scene_before.clear()
2737
+ scene_after.clear()
2738
+ del before_actor, after_actor
2739
+ if brain_actor_before is not None:
2740
+ del brain_actor_before
2741
+ if brain_actor_after is not None:
2742
+ del brain_actor_after
2743
+ del scene_before, scene_after
2744
+ gc.collect()
2745
+ # Clean up temp files if they were created
2746
+ with contextlib.suppress(OSError):
2747
+ os.unlink(tmp_before_path)
2748
+ os.unlink(tmp_after_path)
2749
+ continue
2750
+ raise
2751
+
2752
+ # Combine images side-by-side using PIL/Pillow or imageio
2753
+ try:
2754
+ img_before = Image.open(tmp_before_path)
2755
+ img_after = Image.open(tmp_after_path)
2756
+
2757
+ # Create combined image (twice as wide)
2758
+ combined_width = figure_size[0] * 2
2759
+ combined_height = figure_size[1]
2760
+ combined_img = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
2761
+ combined_img.paste(img_before, (0, 0))
2762
+ combined_img.paste(img_after, (figure_size[0], 0))
2763
+ combined_img.save(str(output_image))
2764
+
2765
+ img_before.close()
2766
+ img_after.close()
2767
+ except ImportError:
2768
+ # Fallback to imageio if PIL not available
2769
+
2770
+ img_before = imageio.imread(tmp_before_path)
2771
+ img_after = imageio.imread(tmp_after_path)
2772
+ combined_img = np.concatenate([img_before, img_after], axis=1)
2773
+ imageio.imwrite(str(output_image), combined_img)
2774
+
2775
+ # Clean up temporary files
2776
+ os.unlink(tmp_before_path)
2777
+ os.unlink(tmp_after_path)
2778
+
2779
+ # Clean up scenes
2780
+ scene_before.clear()
2781
+ scene_after.clear()
2782
+ del before_actor, after_actor
2783
+ if show_glass_brain:
2784
+ del brain_actor_before, brain_actor_after
2785
+ gc.collect()
2786
+
2787
+ generated_views[view_name] = output_image
2788
+
2789
+ except TractographyVisualizationError:
2790
+ raise
2791
+ except (OSError, ValueError, RuntimeError) as e:
2792
+ raise TractographyVisualizationError(
2793
+ f"Failed to compare before/after CCI: {e}",
2794
+ ) from e
2795
+ else:
2796
+ return generated_views
2797
+
2798
def visualize_bundle_assignment(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    n_segments: int = 100,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    colormap: str = "random",
) -> dict[str, Path]:
    """Visualize bundle assignment map using DIPY's assignment_map.

    Assigns every point of every streamline in the target tract to one of
    ``n_segments`` segments of the atlas (model) bundle using DIPY's
    ``assignment_map``, then colour-codes the points by assigned segment.
    This produces the characteristic banding along each streamline. One image
    is rendered per requested anatomical view. Views whose output image
    already exists are not regenerated, and if every requested image exists
    no tract is loaded at all.

    Parameters
    ----------
    tract_file : str | Path
        Path to the target tractography file (streamlines to assign).
    atlas_file : str | Path
        Path to the atlas tractography file (reference bundle for assignment).
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    n_segments : int, optional
        Number of segments to divide the model bundle into. Each point of
        the target tract is assigned to the closest segment. Must be >= 1.
        Default is 100.
    views : list[str] | None, optional
        Views to generate. Each must be a key of ``ANATOMICAL_VIEW_ANGLES``
        (e.g. "coronal", "axial", "sagittal"). If None, all views are
        generated. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images (created if given). If None,
        uses the output directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of each output image in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    colormap : str, optional
        ``"random"`` (the default) draws one random RGB colour per segment.
        This gives maximally distinct neighbouring segments, as in the DIPY
        example; colours are shared by all views of one call but differ
        between calls. Any other value is passed as ``name`` to FURY's
        ``create_colormap`` (e.g. "Spectral", "hsv", "rainbow", "tab20").

    Returns
    -------
    dict[str, Path]
        Mapping from view name to output image path. This includes images
        that already existed. Views that failed because VTK ran out of
        memory are omitted.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, a view name
        is invalid, or ``n_segments`` < 1. An empty target/atlas tractogram
        also raises ``InvalidInputError`` inside the rendering block.
        NOTE(review): if ``InvalidInputError`` subclasses ``ValueError`` it
        is wrapped in ``TractographyVisualizationError`` there.
    TractographyVisualizationError
        If loading (including missing files), assignment or rendering fails.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> views = visualizer.visualize_bundle_assignment(
    ...     "target_tract.trk", "model_tract.trk", n_segments=100
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Validate cheap arguments up front, before any file I/O.
    if n_segments < 1:
        raise InvalidInputError(
            f"n_segments must be a positive integer, got {n_segments}.",
        )

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    # Check which outputs already exist BEFORE loading tracts, so that a fully
    # cached call does not pay the memory cost of loading them.
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"bundle_assignment_{tract_name}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    try:
        # Load both tracts only if at least one view must be generated
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        if not tract.streamlines or len(tract.streamlines) == 0:
            raise InvalidInputError("Target tractogram is empty.")
        if not atlas_tract.streamlines or len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tractogram is empty.")

        # assignment_map requires both tracts in the same coordinate space;
        # map both into the reference image's voxel space.
        ref_img_obj = nib.load(str(ref_img))
        inv_affine = np.linalg.inv(ref_img_obj.affine)  # type: ignore[attr-defined]
        tract_streamlines = transform_streamlines(tract.streamlines, inv_affine)
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, inv_affine)

        # One segment index per point, in the order points appear when
        # iterating the streamlines sequentially (all points of streamline 0,
        # then streamline 1, ...).
        assignment_indices = np.asarray(
            assignment_map(tract_streamlines, atlas_streamlines, n_segments),
        )
        # The atlas is no longer needed; release it before rendering.
        del atlas_tract, atlas_streamlines

        # Per-segment RGB colours, one row per segment.
        if colormap == "random" or colormap is None:
            rng = np.random.default_rng()
            segment_colors = rng.random((n_segments, 3))
        else:
            # Keep only RGB (drop an alpha channel if the colormap returns one).
            segment_colors = np.asarray(
                create_colormap(np.linspace(0, 1, n_segments), name=colormap),
            )[:, :3]
        segment_colors = np.asarray(segment_colors, dtype=np.float32)

        # Vectorised per-point colour lookup, computed once and shared by all
        # views. Out-of-range indices fall back to the first segment colour.
        in_range = (assignment_indices >= 0) & (assignment_indices < len(segment_colors))
        point_colors_array = segment_colors[np.where(in_range, assignment_indices, 0)]

        # actor.line expects an (N points, 3) per-point colour array.
        if point_colors_array.ndim != 2 or point_colors_array.shape[1] != 3:
            msg = (
                f"point_colors_array has incorrect shape: {point_colors_array.shape}. "
                f"Expected (N, 3) where N is the number of points."
            )
            raise ValueError(msg)

        if logger.isEnabledFor(logging.DEBUG):
            n_first = len(tract_streamlines[0])
            first_assignments = assignment_indices[:n_first]
            logger.debug(
                "First streamline: %d points, %d unique segments",
                n_first,
                len(np.unique(first_assignments)),
            )
            logger.debug("Sample assignments (first 10): %s", first_assignments[:10])
            logger.debug("Sample colors (first 3): %s", point_colors_array[:3])

        # Camera targets are the same for every view; compute them once.
        centroid = calculate_centroid(tract_streamlines)
        bbox_size = calculate_bbox_size(tract_streamlines)
        gc.collect()

        for view_name in views_to_generate:
            output_image = output_dir / f"bundle_assignment_{tract_name}_{view_name}.png"

            # Already recorded in generated_views by the pre-check above
            if output_image.exists():
                continue

            scene, brain_actor = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)

            # Streamlines stay fixed; the camera produces the anatomical view.
            tract_actor = actor.line(tract_streamlines, colors=point_colors_array, fake_tube=True, linewidth=6)
            scene.add(tract_actor)

            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            except RuntimeError as e:
                # VTK surfaces std::bad_alloc as RuntimeError; skip only that
                # view on memory failures and propagate anything else.
                error_msg = str(e).lower()
                if "bad_alloc" not in error_msg and "memory" not in error_msg and "allocation" not in error_msg:
                    raise
                logger.exception(
                    "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                    "Tract has %d streamlines with %d total points. "
                    "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                    view_name,
                    len(tract_streamlines),
                    sum(len(sl) for sl in tract_streamlines),
                )
            else:
                generated_views[view_name] = output_image
            finally:
                # Release VTK objects before the next view to keep peak memory low.
                scene.clear()
                del tract_actor, brain_actor, scene
                gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to visualize bundle assignment: {e}",
        ) from e
    else:
        return generated_views
3103
+
3104
def generate_gif(
    self,
    name: str,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    ref_file: str | Path | None = None,  # Alias for ref_img
    output_dir: str | Path | None = None,
) -> Path:
    """Generate a GIF animation of rotating tractography.

    The streamlines are mapped into the reference image's voxel space. They
    and the glass brain are then rotated together around the Z axis through
    the voxel-space origin, with one snapshot per frame (``self.gif_frames``
    frames evenly spaced over 360 degrees). If the GIF already exists it is
    returned without re-rendering.

    Parameters
    ----------
    name : str
        Base name for the output GIF file (without extension).
    tract_file : str | Path
        Path to the tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, ``ref_file`` is used. If that
        is also None, the reference image set during initialization is used.
    ref_file : str | Path | None, optional
        Backward-compatible alias for ``ref_img``; ignored when ``ref_img``
        is given.
    output_dir : str | Path | None, optional
        Output directory (created if given). If None, uses the output
        directory set during initialization.

    Returns
    -------
    Path
        Path to the generated GIF file. If VTK runs out of memory part-way
        through, the frames rendered so far are saved as a shorter GIF.

    Raises
    ------
    InvalidInputError
        If no reference image or output directory is available.
    TractographyVisualizationError
        If loading or rendering fails, or not a single frame could be
        rendered.
    """
    # Apply the backward-compatible alias BEFORE falling back to the instance
    # default; otherwise ``ref_file`` could never take effect.
    if ref_img is None and ref_file is not None:
        ref_img = ref_file

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    gif_filename = output_dir / f"{name}.gif"

    # Skip if file already exists
    if gif_filename.exists():
        logger.debug("Skipping generation of %s (file already exists)", gif_filename)
        return gif_filename

    try:
        # Load tract and map it into the reference image's voxel space
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )

        # The scene is cleared and repopulated for every frame, so only the
        # brain actor needs to exist up front.
        brain_actor = self.get_glass_brain(ref_img)
        scene = window.Scene()
        scene.setBackground(color=(1, 1, 1))

        angles = np.linspace(0, 360, self.gif_frames, endpoint=False)
        rotation_axis = np.array([0, 0, 1])  # Rotate around Z-axis
        rotation_center = np.array([0, 0, 0])
        gif_frames = []

        for angle in angles:
            rot_matrix = Rotation.from_rotvec(
                angle * np.pi / 180 * rotation_axis,
            ).as_matrix()

            # Rotate streamlines around the rotation centre
            rotated_streamlines = [
                np.dot(s - rotation_center, rot_matrix.T) + rotation_center for s in tract_streamlines
            ]
            stream_actor = actor.line(rotated_streamlines)

            # Apply the same rotation to the glass brain via a 4x4 VTK transform
            transform_matrix = np.eye(4)
            transform_matrix[:3, :3] = rot_matrix
            transform_matrix[:3, 3] = rotation_center - np.dot(
                rot_matrix,
                rotation_center,
            )
            transform = vtk.vtkTransform()
            transform.Concatenate(transform_matrix.flatten())
            brain_actor.SetUserTransform(transform)

            scene.clear()
            scene.add(stream_actor)
            scene.add(brain_actor)

            try:
                frame = window.snapshot(scene, size=self.gif_size)
            except RuntimeError as e:
                # VTK surfaces std::bad_alloc as RuntimeError; on memory
                # failure stop rendering and keep the frames captured so far.
                error_msg = str(e).lower()
                if "bad_alloc" not in error_msg and "memory" not in error_msg and "allocation" not in error_msg:
                    raise
                logger.exception(
                    "VTK memory allocation failed during GIF frame snapshot (likely std::bad_alloc). "
                    "Tract has %d streamlines with %d total points. "
                    "Try reducing gif_size or processing with n_jobs=1.",
                    len(tract_streamlines),
                    sum(len(sl) for sl in tract_streamlines),
                )
                scene.clear()
                del stream_actor, rotated_streamlines
                gc.collect()
                if gif_frames:
                    logger.warning("Saving partial GIF with %d frames", len(gif_frames))
                break

            gif_frames.append(frame)

            # Force memory cleanup between frames
            del stream_actor, rotated_streamlines
            gc.collect()

        # Release VTK objects and tract data before encoding the GIF
        scene.clear()
        del brain_actor, scene
        del tract, tract_streamlines, ref_img_obj

        if not gif_frames:
            raise TractographyVisualizationError(
                f"Failed to generate GIF: no frames could be rendered for {tract_file}",
            )

        # Save as optimized GIF
        imageio.mimsave(
            str(gif_filename),
            gif_frames,
            duration=self.gif_duration,
            palettesize=self.gif_palette_size,
        )

        # Clean up frames after saving
        del gif_frames
        gc.collect()
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate GIF: {e}",
        ) from e
    else:
        return gif_filename
3284
+
3285
def convert_gif_to_mp4(
    self,
    gif_path: str | Path,
    mp4_path: str | Path | None = None,
    *,
    fps: int = 10,
) -> Path:
    """Convert a GIF file to MP4 format.

    Frames are streamed from the GIF into an FFMPEG writer using the
    ``libx264`` codec. If the MP4 already exists it is returned unchanged.
    If conversion fails, any partially written MP4 is removed, so that a
    later call does not mistake it for a finished file.

    Parameters
    ----------
    gif_path : str | Path
        Path to the input GIF file.
    mp4_path : str | Path | None, optional
        Path to the output MP4 file (its parent directory is created if
        needed). If None, uses the same name as the GIF with an ``.mp4``
        extension.
    fps : int, optional
        Frames per second for the video. Default is 10.

    Returns
    -------
    Path
        Path to the generated MP4 file.

    Raises
    ------
    TractographyVisualizationError
        If the GIF cannot be read (including when it does not exist) or the
        MP4 cannot be written.
    """
    gif_path_obj = Path(gif_path)
    if mp4_path is None:
        mp4_path = gif_path_obj.with_suffix(".mp4")
    else:
        mp4_path = Path(mp4_path)
        mp4_path.parent.mkdir(parents=True, exist_ok=True)

    # Skip if file already exists
    if mp4_path.exists():
        logger.debug("Skipping conversion of %s to %s (file already exists)", gif_path, mp4_path)
        return mp4_path

    try:
        # Context managers guarantee both handles are closed even if reading
        # or encoding fails part-way through.
        with imageio.get_reader(str(gif_path_obj)) as reader:
            with imageio.get_writer(
                str(mp4_path),
                format="FFMPEG",  # type: ignore[arg-type]
                fps=fps,
                codec="libx264",
            ) as writer:
                for frame in reader:  # type: ignore[attr-defined]
                    writer.append_data(frame)
    except (OSError, ValueError, RuntimeError) as e:
        # A half-written MP4 would satisfy the "already exists" shortcut above
        # on the next call, so remove it (it did not exist before this call).
        with contextlib.suppress(OSError):
            mp4_path.unlink()
        raise TractographyVisualizationError(
            f"Failed to convert GIF to MP4: {e}",
        ) from e
    else:
        return mp4_path
3348
+
3349
def generate_videos(
    self,
    tract_files: list[str | Path],
    ref_img: str | Path | None = None,
    *,
    ref_file: str | Path | None = None,  # Alias for ref_img
    output_dir: str | Path | None = None,
    remove_gifs: bool = True,
) -> dict[str, Path]:
    """Render a rotating MP4 video for each tractography file.

    Each tract is first rendered to an animated GIF with ``generate_gif``.
    The GIF is then transcoded with ``convert_gif_to_mp4``. Tracts whose MP4
    already exists in the output directory are not rendered again.

    Parameters
    ----------
    tract_files : list[str | Path]
        Tractography files to render.
    ref_img : str | Path | None, optional
        Reference image. If None, ``ref_file`` is used. If that is also
        None, the reference image set during initialization is used.
    ref_file : str | Path | None, optional
        Backward-compatible alias for ``ref_img``; only consulted when
        ``ref_img`` is None.
    output_dir : str | Path | None, optional
        Destination directory, created when given explicitly. If None, the
        output directory set during initialization is used.
    remove_gifs : bool, optional
        Delete each intermediate GIF after it has been converted.
        Default is True.

    Returns
    -------
    dict[str, Path]
        Mapping from tract name (the file stem) to its MP4 path.

    Raises
    ------
    InvalidInputError
        If ``tract_files`` is empty, or no output directory or reference
        image can be resolved.
    TractographyVisualizationError
        If rendering or converting any tract fails.
    """
    if not tract_files:
        raise InvalidInputError("No tract files provided.")

    # Legacy alias: when ref_img is unset, ref_file takes its place.
    if ref_img is None:
        ref_img = ref_file

    if output_dir is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
    elif self._output_directory is not None:
        output_dir = self._output_directory
    else:
        raise InvalidInputError(
            "No output directory provided. Set it via constructor or "
            "set_output_directory() method, or pass it as an argument.",
        )

    if ref_img is not None:
        ref_img = Path(ref_img)
    elif self._reference_image is not None:
        ref_img = self._reference_image
    else:
        raise InvalidInputError(
            "No reference image provided. Set it via constructor or "
            "set_reference_image() method, or pass it as an argument.",
        )

    videos: dict[str, Path] = {}

    for tract_file in tract_files:
        try:
            source = Path(tract_file)
            stem = source.stem
            video_path = output_dir / f"{stem}.mp4"

            if video_path.exists():
                logger.debug("Skipping generation of %s (file already exists)", video_path)
            else:
                gif_path = self.generate_gif(
                    name=stem,
                    tract_file=source,
                    ref_img=ref_img,
                    output_dir=output_dir,
                )
                self.convert_gif_to_mp4(gif_path, video_path)
                if remove_gifs:
                    gif_path.unlink()

            videos[stem] = video_path
        except (OSError, ValueError, RuntimeError) as exc:
            raise TractographyVisualizationError(
                f"Failed to generate video for {tract_file}: {exc}",
            ) from exc

    return videos
3447
+
3448
+ def _process_single_tract(
3449
+ self,
3450
+ subject_id: str,
3451
+ tract_name: str,
3452
+ tract_file: str | Path,
3453
+ subject_ref_img: Path,
3454
+ tract_output_dir: Path,
3455
+ subjects_mni_space: dict[str, dict[str, str | Path]] | None,
3456
+ atlas_files: dict[str, str | Path] | None,
3457
+ metric_files: dict[str, dict[str, str | Path]] | None,
3458
+ atlas_ref_img: str | Path | None,
3459
+ *,
3460
+ flip_lr: bool,
3461
+ skip_checks: list[str],
3462
+ **kwargs: Any,
3463
+ ) -> dict[str, str | Path]:
3464
+ """Process a single subject/tract combination.
3465
+
3466
+ This is a helper method used for parallel processing.
3467
+ Returns a dictionary of results for this tract.
3468
+ """
3469
+ tract_results: dict[str, str | Path] = {}
3470
+
3471
+ try:
3472
+ # 1. Standard anatomical views
3473
+ if "anatomical_views" not in skip_checks:
3474
+ try:
3475
+ anatomical_views = self.generate_anatomical_views(
3476
+ tract_file,
3477
+ output_dir=tract_output_dir,
3478
+ ref_img=subject_ref_img,
3479
+ **kwargs,
3480
+ )
3481
+ # Add anatomical views to results
3482
+ for view_name, view_path in anatomical_views.items():
3483
+ tract_results[f"anatomical_{view_name}"] = view_path
3484
+ except (OSError, ValueError, RuntimeError) as e:
3485
+ logger.warning(
3486
+ "Failed to generate anatomical views for %s/%s: %s",
3487
+ subject_id,
3488
+ tract_name,
3489
+ e,
3490
+ )
3491
+
3492
+ # 2. CCI calculation and visualization
3493
+ if "cci" not in skip_checks:
3494
+ try:
3495
+ cci_plots = self.plot_cci(
3496
+ tract_file, # type: ignore[arg-type]
3497
+ output_dir=tract_output_dir,
3498
+ ref_img=subject_ref_img,
3499
+ **kwargs,
3500
+ )
3501
+ # Add CCI plots to results
3502
+ for plot_type, plot_path in cci_plots.items():
3503
+ tract_results[f"cci_{plot_type}"] = plot_path
3504
+ except (OSError, ValueError, RuntimeError) as e:
3505
+ logger.warning(
3506
+ "Failed to generate CCI plots for %s/%s: %s",
3507
+ subject_id,
3508
+ tract_name,
3509
+ e,
3510
+ )
3511
+
3512
+ # 3. Before/after CCI comparison
3513
+ if "before_after_cci" not in skip_checks:
3514
+ try:
3515
+ before_after_views = self.compare_before_after_cci(
3516
+ tract_file,
3517
+ output_dir=tract_output_dir,
3518
+ ref_img=subject_ref_img,
3519
+ **kwargs,
3520
+ )
3521
+ # Add before/after views to results
3522
+ for view_name, view_path in before_after_views.items():
3523
+ tract_results[f"before_after_cci_{view_name}"] = view_path
3524
+ except (OSError, ValueError, RuntimeError) as e:
3525
+ logger.warning(
3526
+ "Failed to generate before/after CCI comparison for %s/%s: %s",
3527
+ subject_id,
3528
+ tract_name,
3529
+ e,
3530
+ )
3531
+
3532
+ # 4. Atlas comparison (uses MNI space tracts for subject views)
3533
+ if "atlas_comparison" not in skip_checks and atlas_files is not None and tract_name in atlas_files:
3534
+ try:
3535
+ atlas_file = atlas_files[tract_name]
3536
+ # Generate subject views in MNI space if available, otherwise skip
3537
+ if (
3538
+ subjects_mni_space is not None
3539
+ and subject_id in subjects_mni_space
3540
+ and tract_name in subjects_mni_space[subject_id]
3541
+ ):
3542
+ tract_file_mni = subjects_mni_space[subject_id][tract_name]
3543
+ # Generate subject views in MNI space
3544
+ subject_mni_views = self.generate_anatomical_views(
3545
+ tract_file_mni,
3546
+ ref_img=atlas_ref_img, # Use atlas ref image for MNI space
3547
+ output_dir=tract_output_dir,
3548
+ **kwargs,
3549
+ )
3550
+ # Add subject MNI views to results
3551
+ for view_name, view_path in subject_mni_views.items():
3552
+ tract_results[f"subject_mni_{view_name}"] = view_path
3553
+
3554
+ # Generate atlas views
3555
+ atlas_views = self.generate_atlas_views(
3556
+ atlas_file,
3557
+ atlas_ref_img=atlas_ref_img,
3558
+ flip_lr=flip_lr,
3559
+ output_dir=tract_output_dir,
3560
+ **kwargs,
3561
+ )
3562
+ # Add atlas views to results
3563
+ for view_name, view_path in atlas_views.items():
3564
+ tract_results[f"atlas_{view_name}"] = view_path
3565
+ except (OSError, ValueError, RuntimeError) as e:
3566
+ logger.warning(
3567
+ "Failed to generate atlas comparison views for %s/%s: %s",
3568
+ subject_id,
3569
+ tract_name,
3570
+ e,
3571
+ )
3572
+
3573
+ # 5. Shape similarity (uses MNI space tracts)
3574
+ if (
3575
+ "shape_similarity" not in skip_checks and atlas_files is not None and subjects_mni_space is not None
3576
+ ) and (
3577
+ subject_id in subjects_mni_space
3578
+ and tract_name in subjects_mni_space[subject_id]
3579
+ and tract_name in atlas_files
3580
+ ):
3581
+ try:
3582
+ # Use MNI space tract for shape similarity
3583
+ tract_file_mni = subjects_mni_space[subject_id][tract_name]
3584
+ atlas_file = atlas_files[tract_name]
3585
+ # Calculate shape similarity score
3586
+ similarity_score = self.calculate_shape_similarity(
3587
+ tract_file_mni,
3588
+ atlas_file,
3589
+ atlas_ref_img=atlas_ref_img,
3590
+ flip_lr=flip_lr,
3591
+ **kwargs,
3592
+ )
3593
+ tract_results["shape_similarity_score"] = str(similarity_score)
3594
+
3595
+ # Visualize shape similarity
3596
+ similarity_views = self.visualize_shape_similarity(
3597
+ tract_file_mni,
3598
+ atlas_file,
3599
+ atlas_ref_img=atlas_ref_img,
3600
+ flip_lr=flip_lr,
3601
+ output_dir=tract_output_dir,
3602
+ **kwargs,
3603
+ )
3604
+ # Add similarity views to results
3605
+ for view_name, view_path in similarity_views.items():
3606
+ tract_results[f"similarity_{view_name}"] = view_path
3607
+ except (OSError, ValueError, RuntimeError, IndexError) as e:
3608
+ logger.warning(
3609
+ "Failed to calculate/visualize shape similarity for %s/%s: %s",
3610
+ subject_id,
3611
+ tract_name,
3612
+ e,
3613
+ )
3614
+
3615
+ # 6. AFQ profile (requires metric files per subject and atlas files as model files)
3616
+ if ("afq_profile" not in skip_checks and metric_files is not None and atlas_files is not None) and (
3617
+ subject_id in metric_files and tract_name in atlas_files
3618
+ ):
3619
+ model_file = atlas_files[tract_name] # Use atlas file as model file
3620
+ for metric_name, metric_file in metric_files[subject_id].items():
3621
+ try:
3622
+ afq_plots = self.plot_afq(
3623
+ metric_file,
3624
+ metric_name,
3625
+ tract_file,
3626
+ model_file,
3627
+ ref_img=subject_ref_img,
3628
+ output_dir=tract_output_dir,
3629
+ **kwargs,
3630
+ )
3631
+ # Add AFQ plots to results
3632
+ for plot_type, plot_path in afq_plots.items():
3633
+ tract_results[f"afq_{metric_name}_{plot_type}"] = plot_path
3634
+ except (OSError, ValueError, RuntimeError, IndexError) as e:
3635
+ logger.warning(
3636
+ "Failed to generate AFQ profile for %s/%s/%s: %s",
3637
+ subject_id,
3638
+ tract_name,
3639
+ metric_name,
3640
+ e,
3641
+ )
3642
+
3643
+ # 7. Bundle assignment (uses MNI space tracts and atlas files as model files)
3644
+ if (
3645
+ "bundle_assignment" not in skip_checks and atlas_files is not None and subjects_mni_space is not None
3646
+ ) and (
3647
+ subject_id in subjects_mni_space
3648
+ and tract_name in subjects_mni_space[subject_id]
3649
+ and tract_name in atlas_files
3650
+ ):
3651
+ try:
3652
+ # Use MNI space tract for bundle assignment
3653
+ tract_file_mni = subjects_mni_space[subject_id][tract_name]
3654
+ model_file = atlas_files[tract_name] # Use atlas file as model file
3655
+ assignment_views = self.visualize_bundle_assignment(
3656
+ tract_file_mni,
3657
+ model_file,
3658
+ output_dir=tract_output_dir,
3659
+ ref_img=subject_ref_img,
3660
+ **kwargs,
3661
+ )
3662
+ # Add assignment views to results
3663
+ for view_name, view_path in assignment_views.items():
3664
+ tract_results[f"assignment_{view_name}"] = view_path
3665
+ except (OSError, ValueError, RuntimeError, IndexError) as e:
3666
+ logger.warning(
3667
+ "Failed to generate bundle assignment for %s/%s: %s",
3668
+ subject_id,
3669
+ tract_name,
3670
+ e,
3671
+ )
3672
+
3673
+ except (OSError, ValueError, RuntimeError):
3674
+ logger.exception("Error processing %s/%s", subject_id, tract_name)
3675
+ finally:
3676
+ # Clean up memory after processing this tract
3677
+ gc.collect()
3678
+
3679
+ return tract_results
3680
+
3681
+ def run_quality_check_workflow(
3682
+ self,
3683
+ subjects_original_space: dict[str, dict[str, str | Path]],
3684
+ ref_img: str | Path | dict[str, str | Path] | None = None,
3685
+ *,
3686
+ subjects_mni_space: dict[str, dict[str, str | Path]] | None = None,
3687
+ atlas_files: dict[str, str | Path] | None = None,
3688
+ metric_files: dict[str, dict[str, str | Path]] | None = None,
3689
+ atlas_ref_img: str | Path | None = None,
3690
+ flip_lr: bool = False,
3691
+ output_dir: str | Path | None = None,
3692
+ html_output: str | Path | None = None,
3693
+ skip_checks: list[str] | None = None,
3694
+ n_jobs: int | None = None,
3695
+ **kwargs: Any,
3696
+ ) -> dict[str, dict[str, dict[str, str | Path]]]:
3697
+ """Run comprehensive quality checks for multiple subjects and tracts.
3698
+
3699
+ This workflow function orchestrates all available quality check methods
3700
+ for each subject/tract combination and generates an HTML report.
3701
+
3702
+ Different quality checks require tracts in different coordinate spaces:
3703
+ - **Original space** (subjects_original_space): Used for anatomical views,
3704
+ CCI calculations, before/after CCI comparison, and AFQ profiles
3705
+ (which need to align with subject-specific metric files).
3706
+ - **MNI/Atlas space** (subjects_mni_space): Used for shape similarity
3707
+ calculations and bundle assignment (which compare with atlas files).
3708
+
3709
+ Parameters
3710
+ ----------
3711
+ subjects_original_space : dict[str, dict[str, str | Path]]
3712
+ Dictionary mapping subject IDs to their tract files in original/native space.
3713
+ Format: {subject_id: {tract_name: tract_file_path}}
3714
+ Example:
3715
+ {
3716
+ "sub-001": {
3717
+ "AF_L": "path/to/sub-001_AF_L_original.trk",
3718
+ "AF_R": "path/to/sub-001_AF_R_original.trk"
3719
+ },
3720
+ "sub-002": {
3721
+ "AF_L": "path/to/sub-002_AF_L_original.trk"
3722
+ }
3723
+ }
3724
+ Used for: anatomical views, CCI, before/after CCI, AFQ profiles.
3725
+ ref_img : str | Path | dict[str, str | Path] | None, optional
3726
+ Reference image(s) for subjects. Can be:
3727
+ - A single path (str | Path): Used for all subjects
3728
+ - A dictionary mapping subject IDs to reference images:
3729
+ {subject_id: ref_image_path}
3730
+ - None: Uses the reference image set during initialization or via
3731
+ `set_reference_image()` for all subjects
3732
+ Example for per-subject reference images:
3733
+ {"sub-001": "path/to/sub-001_t1w.nii.gz", "sub-002": "path/to/sub-002_t1w.nii.gz"}
3734
+ subjects_mni_space : dict[str, dict[str, str | Path]] | None, optional
3735
+ Dictionary mapping subject IDs to their tract files in MNI/atlas space.
3736
+ Format: {subject_id: {tract_name: tract_file_path}}
3737
+ Example:
3738
+ {
3739
+ "sub-001": {
3740
+ "AF_L": "path/to/sub-001_AF_L_mni.trk",
3741
+ "AF_R": "path/to/sub-001_AF_R_mni.trk"
3742
+ },
3743
+ "sub-002": {
3744
+ "AF_L": "path/to/sub-002_AF_L_mni.trk"
3745
+ }
3746
+ }
3747
+ Used for: shape similarity, bundle assignment.
3748
+ If None, these checks will be skipped.
3749
+ atlas_files : dict[str, str | Path] | None, optional
3750
+ Dictionary mapping tract names to their corresponding atlas/model files.
3751
+ These files are shared across all subjects and used for:
3752
+ - Atlas comparison visualizations
3753
+ - Shape similarity calculations
3754
+ - AFQ profile calculations (as model files)
3755
+ - Bundle assignment visualizations (as model files)
3756
+ Format: {tract_name: atlas_file_path}
3757
+ Example: {"AF_L": "path/to/atlas_AF_L.trk", "AF_R": "path/to/atlas_AF_R.trk"}
3758
+ metric_files : dict[str, dict[str, str | Path]] | None, optional
3759
+ Dictionary mapping subject IDs to their metric files.
3760
+ All tracts within a subject will use the same metric files.
3761
+ Format: {subject_id: {metric_name: metric_file_path}}
3762
+ Example:
3763
+ {
3764
+ "sub-001": {"FA": "path/to/sub-001_FA.nii.gz", "MD": "path/to/sub-001_MD.nii.gz"},
3765
+ "sub-002": {"FA": "path/to/sub-002_FA.nii.gz"}
3766
+ }
3767
+ If provided, AFQ profile calculations will be run for all tracts in each subject.
3768
+ atlas_ref_img : str | Path | None, optional
3769
+ Path to the reference image matching the atlas coordinate space
3770
+ (e.g., MNI template). Required if atlas files are in a different
3771
+ coordinate space than subject tracts.
3772
+ flip_lr : bool, optional
3773
+ Whether to flip left-right (X-axis) when transforming atlas.
3774
+ Default is False.
3775
+ output_dir : str | Path | None, optional
3776
+ Output directory for generated files. If None, uses the output
3777
+ directory set during initialization.
3778
+ html_output : str | Path | None, optional
3779
+ Path for the HTML report file. If None, creates "quality_check_report.html"
3780
+ in the output directory.
3781
+ skip_checks : list[str] | None, optional
3782
+ List of quality checks to skip. Valid options:
3783
+ - "anatomical_views": Skip standard anatomical views
3784
+ - "atlas_comparison": Skip atlas comparison views
3785
+ - "cci": Skip CCI calculation and visualization
3786
+ - "before_after_cci": Skip before/after CCI comparison
3787
+ - "afq_profile": Skip AFQ profile visualization
3788
+ - "bundle_assignment": Skip bundle assignment visualization
3789
+ - "shape_similarity": Skip shape similarity calculation and visualization
3790
+ n_jobs : int | None, optional
3791
+ Number of parallel jobs to run for processing multiple subjects/tracts.
3792
+ If None, uses the value set during initialization (default: 1).
3793
+ Use -1 to automatically determine optimal number based on available
3794
+ resources (respects SLURM allocations and OpenMP thread settings to
3795
+ prevent oversubscription). Only effective when processing multiple
3796
+ subjects/tracts. Default is None.
3797
+
3798
+ Note: When running under SLURM, -1 will automatically use
3799
+ SLURM_CPUS_PER_TASK or SLURM_JOB_CPUS_PER_NODE. If OMP_NUM_THREADS is set,
3800
+ it will divide the available CPUs by the number of OpenMP threads to
3801
+ prevent resource contention.
3802
+ **kwargs
3803
+ Additional keyword arguments passed to individual quality check methods.
3804
+
3805
+ Returns
3806
+ -------
3807
+ dict[str, dict[str, dict[str, str | Path]]]
3808
+ Nested dictionary structure: {subject_id: {tract_name: {media_type: file_path}}}
3809
+ This structure is compatible with `create_quality_check_html()`.
3810
+
3811
+ Raises
3812
+ ------
3813
+ InvalidInputError
3814
+ If required files are missing or invalid.
3815
+ TractographyVisualizationError
3816
+ If quality check workflow fails.
3817
+
3818
+ Examples
3819
+ --------
3820
+ Single subject or all subjects share same reference image:
3821
+ >>> visualizer = TractographyVisualizer(output_directory="output/")
3822
+ >>> subjects_original = {
3823
+ ... "sub-001": {
3824
+ ... "AF_L": "sub-001_AF_L_original.trk",
3825
+ ... "AF_R": "sub-001_AF_R_original.trk",
3826
+ ... }
3827
+ ... }
3828
+ >>> results = visualizer.run_quality_check_workflow(
3829
+ ... subjects_original_space=subjects_original,
3830
+ ... ref_img="shared_t1w.nii.gz", # Single image for all subjects
3831
+ ... html_output="quality_report.html",
3832
+ ... )
3833
+
3834
+ Multiple subjects with different reference images:
3835
+ >>> visualizer = TractographyVisualizer(output_directory="output/")
3836
+ >>> subjects_original = {
3837
+ ... "sub-001": {"AF_L": "sub-001_AF_L_original.trk"},
3838
+ ... "sub-002": {"AF_L": "sub-002_AF_L_original.trk"},
3839
+ ... }
3840
+ >>> # Each subject has its own reference image
3841
+ >>> ref_images = {
3842
+ ... "sub-001": "sub-001_t1w.nii.gz",
3843
+ ... "sub-002": "sub-002_t1w.nii.gz",
3844
+ ... }
3845
+ >>> subjects_mni = {
3846
+ ... "sub-001": {"AF_L": "sub-001_AF_L_mni.trk"},
3847
+ ... "sub-002": {"AF_L": "sub-002_AF_L_mni.trk"},
3848
+ ... }
3849
+ >>> atlas_files = {"AF_L": "atlas_AF_L.trk"}
3850
+ >>> metric_files = {
3851
+ ... "sub-001": {"FA": "sub-001_FA.nii.gz"},
3852
+ ... "sub-002": {"FA": "sub-002_FA.nii.gz"},
3853
+ ... }
3854
+ >>> results = visualizer.run_quality_check_workflow(
3855
+ ... subjects_original_space=subjects_original,
3856
+ ... ref_img=ref_images, # Dictionary mapping subject_id -> ref_image
3857
+ ... subjects_mni_space=subjects_mni,
3858
+ ... atlas_files=atlas_files,
3859
+ ... metric_files=metric_files,
3860
+ ... html_output="quality_report.html",
3861
+ ... )
3862
+ """
3863
+ # Initialize XVFB (X Virtual Framebuffer) if requested for headless environments
3864
+ vdisplay = None
3865
+ if os.environ.get("XVFB", "").lower() in ("1", "true", "yes"):
3866
+ logger.info("Initializing XVFB for headless rendering")
3867
+ vdisplay = Xvfb()
3868
+ vdisplay.start()
3869
+
3870
+ # Get output directory
3871
+ if output_dir is None:
3872
+ if self._output_directory is None:
3873
+ raise InvalidInputError(
3874
+ "No output directory provided. Set it via constructor or "
3875
+ "set_output_directory() method, or pass it as an argument.",
3876
+ )
3877
+ output_dir = self._output_directory
3878
+ else:
3879
+ output_dir = Path(output_dir)
3880
+ output_dir.mkdir(parents=True, exist_ok=True)
3881
+
3882
+ # Handle reference image(s) - can be single path or dict mapping subject_id -> path
3883
+ if ref_img is None:
3884
+ if self._reference_image is None:
3885
+ raise InvalidInputError(
3886
+ "No reference image provided. Set it via constructor or "
3887
+ "set_reference_image() method, or pass it as an argument.",
3888
+ )
3889
+ # Use instance reference image for all subjects
3890
+ subject_ref_imgs: dict[str, Path] = dict.fromkeys(subjects_original_space.keys(), self._reference_image)
3891
+ elif isinstance(ref_img, dict):
3892
+ # Dictionary mapping subject IDs to reference images
3893
+ subject_ref_imgs = {subject_id: Path(path) for subject_id, path in ref_img.items()}
3894
+ # Validate all subjects have reference images
3895
+ missing = set(subjects_original_space.keys()) - set(subject_ref_imgs.keys())
3896
+ if missing:
3897
+ raise InvalidInputError(
3898
+ f"Missing reference images for subjects: {', '.join(sorted(missing))}",
3899
+ )
3900
+ else:
3901
+ # Single reference image for all subjects
3902
+ single_ref_img = Path(ref_img)
3903
+ subject_ref_imgs = dict.fromkeys(subjects_original_space.keys(), single_ref_img)
3904
+
3905
+ # Set default skip_checks
3906
+ if skip_checks is None:
3907
+ skip_checks = []
3908
+
3909
+ # Determine number of jobs to use
3910
+ if n_jobs is None:
3911
+ n_jobs = self.n_jobs
3912
+ elif n_jobs == -1:
3913
+ # Use optimal n_jobs considering SLURM and OpenMP settings
3914
+ base_n_jobs = _get_optimal_n_jobs()
3915
+ # Further reduce if memory is limited
3916
+ n_jobs = _get_n_jobs_with_memory_limit(
3917
+ base_n_jobs,
3918
+ estimated_memory_per_job_mb=2000.0, # ~2GB per worker
3919
+ safety_margin=0.2,
3920
+ )
3921
+ else:
3922
+ n_jobs = max(1, n_jobs)
3923
+ # Still check memory even if n_jobs is explicitly set
3924
+ if psutil is not None:
3925
+ n_jobs = _get_n_jobs_with_memory_limit(
3926
+ n_jobs,
3927
+ estimated_memory_per_job_mb=2000.0,
3928
+ safety_margin=0.2,
3929
+ )
3930
+
3931
+ # Log the final n_jobs value for debugging
3932
+ logger.debug("Using n_jobs=%d for parallel processing", n_jobs)
3933
+
3934
+ # Initialize results dictionary
3935
+ results: dict[str, dict[str, dict[str, str | Path]]] = {}
3936
+
3937
+ # Prepare all tasks (subject_id, tract_name, tract_file combinations)
3938
+ tasks: list[tuple[str, str, str | Path, Path, Path]] = []
3939
+ for subject_id, tracts in subjects_original_space.items():
3940
+ subject_ref_img = subject_ref_imgs[subject_id]
3941
+ subject_output_dir = output_dir / subject_id
3942
+ subject_output_dir.mkdir(parents=True, exist_ok=True)
3943
+
3944
+ for tract_name, tract_file in tracts.items():
3945
+ tract_path = Path(tract_file)
3946
+ if not tract_path.exists():
3947
+ raise FileNotFoundError(f"Tract file not found: {tract_file}")
3948
+
3949
+ tract_output_dir = subject_output_dir / tract_name
3950
+ tract_output_dir.mkdir(parents=True, exist_ok=True)
3951
+
3952
+ tasks.append((subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir))
3953
+
3954
+ # Prepare visualizer parameters for worker processes
3955
+ visualizer_params = {
3956
+ "gif_size": self.gif_size,
3957
+ "gif_duration": self.gif_duration,
3958
+ "gif_palette_size": self.gif_palette_size,
3959
+ "gif_frames": self.gif_frames,
3960
+ "min_streamline_length": self.min_streamline_length,
3961
+ "cci_threshold": self.cci_threshold,
3962
+ "afq_resample_points": self.afq_resample_points,
3963
+ "n_jobs": 1, # Workers don't need parallelization
3964
+ }
3965
+
3966
+ # Process tasks in parallel or sequentially
3967
+ if n_jobs > 1 and len(tasks) > 1:
3968
+ logger.info("Processing %d tracts using %d workers", len(tasks), n_jobs)
3969
+ try:
3970
+ with ProcessPoolExecutor(max_workers=n_jobs) as executor:
3971
+ futures = {
3972
+ executor.submit(
3973
+ _process_tract_worker,
3974
+ subject_id=subject_id,
3975
+ tract_name=tract_name,
3976
+ tract_file=tract_file,
3977
+ subject_ref_img=subject_ref_img,
3978
+ tract_output_dir=tract_output_dir,
3979
+ subjects_mni_space=subjects_mni_space,
3980
+ atlas_files=atlas_files,
3981
+ metric_files=metric_files,
3982
+ atlas_ref_img=atlas_ref_img,
3983
+ flip_lr=flip_lr,
3984
+ skip_checks=skip_checks,
3985
+ visualizer_params=visualizer_params,
3986
+ **kwargs,
3987
+ ): (subject_id, tract_name)
3988
+ for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in tasks
3989
+ }
3990
+
3991
+ # Collect results as they complete
3992
+ broken_pool_detected = False
3993
+ completed_count = 0
3994
+ try:
3995
+ for future in as_completed(futures):
3996
+ subject_id, tract_name = futures[future]
3997
+ try:
3998
+ # Get result with timeout - this is where BrokenProcessPool is raised
3999
+ result_subject_id, result_tract_name, tract_results = future.result(timeout=None)
4000
+ if result_subject_id not in results:
4001
+ results[result_subject_id] = {}
4002
+ results[result_subject_id][result_tract_name] = tract_results
4003
+ # Clean up the future reference
4004
+ del tract_results
4005
+ except RuntimeError as e:
4006
+ # RuntimeError catches BrokenProcessPool (which is a subclass of RuntimeError)
4007
+ # Check if this is a BrokenProcessPool by checking the error message and type name
4008
+ error_type = type(e).__name__
4009
+ error_msg = str(e)
4010
+ if (
4011
+ "BrokenProcessPool" in error_type
4012
+ or "BrokenProcessPool" in error_msg
4013
+ or "process pool" in error_msg.lower()
4014
+ or "was terminated abruptly" in error_msg
4015
+ ):
4016
+ broken_pool_detected = True
4017
+ logger.exception(
4018
+ "BrokenProcessPool detected for %s/%s. Worker process may have crashed",
4019
+ subject_id,
4020
+ tract_name,
4021
+ )
4022
+ # Mark this task as failed
4023
+ if subject_id not in results:
4024
+ results[subject_id] = {}
4025
+ if tract_name not in results[subject_id]:
4026
+ results[subject_id][tract_name] = {}
4027
+ # Break out of the loop immediately - pool is broken, can't process more
4028
+ break
4029
+ logger.exception("RuntimeError processing %s/%s", subject_id, tract_name)
4030
+ if subject_id not in results:
4031
+ results[subject_id] = {}
4032
+ if tract_name not in results[subject_id]:
4033
+ results[subject_id][tract_name] = {}
4034
+ except (OSError, ValueError) as e:
4035
+ # OSError catches system-level errors
4036
+ # ValueError catches other processing errors
4037
+ logger.exception("Error processing %s/%s: %s", subject_id, tract_name, type(e).__name__)
4038
+ if subject_id not in results:
4039
+ results[subject_id] = {}
4040
+ if tract_name not in results[subject_id]:
4041
+ results[subject_id][tract_name] = {}
4042
+ except Exception as e:
4043
+ # Catch any other unexpected exceptions, including BrokenProcessPool
4044
+ error_type = type(e).__name__
4045
+ error_msg = str(e)
4046
+ if (
4047
+ "BrokenProcessPool" in error_type
4048
+ or "BrokenProcessPool" in error_msg
4049
+ or "process pool" in error_msg.lower()
4050
+ or "was terminated abruptly" in error_msg
4051
+ ):
4052
+ broken_pool_detected = True
4053
+ logger.exception(
4054
+ "BrokenProcessPool detected for %s/%s (caught as Exception). Worker process may have crashed",
4055
+ subject_id,
4056
+ tract_name,
4057
+ )
4058
+ # Mark this task as failed
4059
+ if subject_id not in results:
4060
+ results[subject_id] = {}
4061
+ if tract_name not in results[subject_id]:
4062
+ results[subject_id][tract_name] = {}
4063
+ # Break out of the loop immediately - pool is broken, can't process more
4064
+ break
4065
+ logger.exception(
4066
+ "Unexpected error processing %s/%s: %s",
4067
+ subject_id,
4068
+ tract_name,
4069
+ type(e).__name__,
4070
+ )
4071
+ if subject_id not in results:
4072
+ results[subject_id] = {}
4073
+ if tract_name not in results[subject_id]:
4074
+ results[subject_id][tract_name] = {}
4075
+ finally:
4076
+ # Increment counter regardless of success/failure
4077
+ completed_count += 1
4078
+
4079
+ # Clean up future reference and remove from futures dict
4080
+ # This is critical to prevent memory accumulation
4081
+ futures.pop(future, None)
4082
+ del future
4083
+
4084
+ # Periodic garbage collection every 10 completed tasks
4085
+ # This helps prevent memory buildup during long-running jobs
4086
+ if completed_count > 0 and completed_count % 10 == 0:
4087
+ gc.collect()
4088
+ except RuntimeError as e:
4089
+ # Catch BrokenProcessPool that might break out of the loop
4090
+ error_type = type(e).__name__
4091
+ error_msg = str(e)
4092
+ if (
4093
+ "BrokenProcessPool" in error_type
4094
+ or "BrokenProcessPool" in error_msg
4095
+ or "process pool" in error_msg.lower()
4096
+ or "was terminated abruptly" in error_msg
4097
+ ):
4098
+ broken_pool_detected = True
4099
+ logger.exception(
4100
+ "BrokenProcessPool detected during result collection. Falling back to sequential processing",
4101
+ )
4102
+ else:
4103
+ # Re-raise if it's a different RuntimeError
4104
+ raise
4105
+ except Exception as e:
4106
+ # Catch any other exceptions that might break the loop
4107
+ error_type = type(e).__name__
4108
+ error_msg = str(e)
4109
+ if (
4110
+ "BrokenProcessPool" in error_type
4111
+ or "BrokenProcessPool" in error_msg
4112
+ or "process pool" in error_msg.lower()
4113
+ or "was terminated abruptly" in error_msg
4114
+ ):
4115
+ broken_pool_detected = True
4116
+ logger.exception(
4117
+ "BrokenProcessPool detected during result collection (caught as Exception). Falling back to sequential processing",
4118
+ )
4119
+ else:
4120
+ # Log but don't re-raise - try to continue with what we have
4121
+ logger.exception("Unexpected error during result collection: %s", type(e).__name__)
4122
+
4123
+ # If we detected a broken process pool, break out and fall back to sequential
4124
+ if broken_pool_detected:
4125
+ # Raise error to trigger fallback to sequential processing
4126
+ # The exception will be caught by the outer try-except block
4127
+ raise RuntimeError("BrokenProcessPool detected, falling back to sequential processing") # noqa: TRY301
4128
+
4129
+ # Clear futures dictionary to free memory
4130
+ futures.clear()
4131
+
4132
+ # Force garbage collection after all parallel tasks complete
4133
+ # Run multiple times to handle circular references
4134
+ for _ in range(3):
4135
+ gc.collect()
4136
+ except RuntimeError as e:
4137
+ # RuntimeError catches BrokenProcessPool (which is a subclass of RuntimeError)
4138
+ # and other runtime errors that might occur with process pools
4139
+ error_msg = str(e)
4140
+ if "BrokenProcessPool" in error_msg or "falling back" in error_msg.lower():
4141
+ logger.warning(
4142
+ "Process pool error occurred (worker process may have crashed with core dump). "
4143
+ "Falling back to sequential processing. "
4144
+ "If this persists, try: (1) reducing n_jobs, (2) increasing memory allocation, "
4145
+ "or (3) using n_jobs=1 to force sequential processing.",
4146
+ )
4147
+ else:
4148
+ logger.exception("RuntimeError in process pool. Falling back to sequential processing")
4149
+
4150
+ # Fall back to sequential processing if process pool fails
4151
+ # Only process tasks that haven't been completed yet
4152
+ remaining_tasks = [
4153
+ (s_id, t_name, t_file, s_ref, t_out)
4154
+ for s_id, t_name, t_file, s_ref, t_out in tasks
4155
+ if s_id not in results or t_name not in results.get(s_id, {})
4156
+ ]
4157
+
4158
+ if remaining_tasks:
4159
+ logger.info("Processing %d remaining tracts sequentially (fallback)", len(remaining_tasks))
4160
+ for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in remaining_tasks:
4161
+ if subject_id not in results:
4162
+ results[subject_id] = {}
4163
+ # Skip if already processed
4164
+ if tract_name in results[subject_id]:
4165
+ continue
4166
+ try:
4167
+ tract_result = self._process_single_tract(
4168
+ subject_id=subject_id,
4169
+ tract_name=tract_name,
4170
+ tract_file=tract_file,
4171
+ subject_ref_img=subject_ref_img,
4172
+ tract_output_dir=tract_output_dir,
4173
+ subjects_mni_space=subjects_mni_space,
4174
+ atlas_files=atlas_files,
4175
+ metric_files=metric_files,
4176
+ atlas_ref_img=atlas_ref_img,
4177
+ flip_lr=flip_lr,
4178
+ skip_checks=skip_checks,
4179
+ **kwargs,
4180
+ )
4181
+ results[subject_id][tract_name] = tract_result
4182
+ # Clean up result reference
4183
+ del tract_result
4184
+ except (OSError, ValueError, RuntimeError, MemoryError) as process_error:
4185
+ logger.exception(
4186
+ "Error processing %s/%s in fallback mode (%s)",
4187
+ subject_id,
4188
+ tract_name,
4189
+ type(process_error).__name__,
4190
+ )
4191
+ results[subject_id][tract_name] = {}
4192
+ except Exception as process_error:
4193
+ logger.exception(
4194
+ "Unexpected error processing %s/%s in fallback mode (%s)",
4195
+ subject_id,
4196
+ tract_name,
4197
+ type(process_error).__name__,
4198
+ )
4199
+ results[subject_id][tract_name] = {}
4200
+ finally:
4201
+ # Force garbage collection after each tract in fallback mode
4202
+ # Run multiple times to handle circular references in VTK objects
4203
+ for _ in range(2):
4204
+ gc.collect()
4205
+ else:
4206
+ logger.info("All tasks already completed, skipping fallback processing")
4207
+ else:
4208
+ # Sequential processing
4209
+ logger.info("Processing %d tracts sequentially", len(tasks))
4210
+ for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in tasks:
4211
+ if subject_id not in results:
4212
+ results[subject_id] = {}
4213
+ try:
4214
+ tract_result = self._process_single_tract(
4215
+ subject_id=subject_id,
4216
+ tract_name=tract_name,
4217
+ tract_file=tract_file,
4218
+ subject_ref_img=subject_ref_img,
4219
+ tract_output_dir=tract_output_dir,
4220
+ subjects_mni_space=subjects_mni_space,
4221
+ atlas_files=atlas_files,
4222
+ metric_files=metric_files,
4223
+ atlas_ref_img=atlas_ref_img,
4224
+ flip_lr=flip_lr,
4225
+ skip_checks=skip_checks,
4226
+ **kwargs,
4227
+ )
4228
+ results[subject_id][tract_name] = tract_result
4229
+ # Clean up result reference
4230
+ del tract_result
4231
+ finally:
4232
+ # Clean up after each tract in sequential processing
4233
+ # Run multiple times to handle circular references in VTK objects
4234
+ for _ in range(2):
4235
+ gc.collect()
4236
+
4237
+ # Generate HTML report
4238
+ html_output = output_dir / "quality_check_report.html" if html_output is None else Path(html_output)
4239
+
4240
+ # Convert Path objects to strings for HTML function
4241
+ results_for_html: dict[str, dict[str, dict[str, str]]] = {}
4242
+ for subject_id, subject_tracts in results.items():
4243
+ results_for_html[subject_id] = {}
4244
+ if isinstance(subject_tracts, dict):
4245
+ for tract_name, media_dict in subject_tracts.items():
4246
+ results_for_html[subject_id][tract_name] = {}
4247
+ if isinstance(media_dict, dict):
4248
+ for media_type, file_path in media_dict.items():
4249
+ # Convert Path to string, or keep as string/number
4250
+ if isinstance(file_path, Path):
4251
+ results_for_html[subject_id][tract_name][media_type] = str(file_path)
4252
+ else:
4253
+ results_for_html[subject_id][tract_name][media_type] = str(file_path)
4254
+
4255
+ try:
4256
+ create_quality_check_html(
4257
+ results_for_html,
4258
+ str(html_output),
4259
+ title="Tractography Quality Check Report",
4260
+ )
4261
+ logger.info("Quality check report generated: %s", html_output)
4262
+ except (OSError, ValueError, RuntimeError) as e:
4263
+ logger.warning("Failed to generate HTML report: %s", e)
4264
+ finally:
4265
+ # Clean up HTML conversion dictionary after use
4266
+ del results_for_html
4267
+ gc.collect()
4268
+ # Stop XVFB if it was started
4269
+ if vdisplay is not None:
4270
+ logger.info("Stopping XVFB")
4271
+ vdisplay.stop()
4272
+ return results