pytractoviz 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytractoviz/__init__.py +11 -0
- pytractoviz/__main__.py +14 -0
- pytractoviz/_internal/__init__.py +0 -0
- pytractoviz/_internal/cli.py +59 -0
- pytractoviz/_internal/debug.py +110 -0
- pytractoviz/html.py +845 -0
- pytractoviz/py.typed +0 -0
- pytractoviz/utils.py +220 -0
- pytractoviz/viz.py +4272 -0
- pytractoviz-0.2.14.dist-info/METADATA +53 -0
- pytractoviz-0.2.14.dist-info/RECORD +14 -0
- pytractoviz-0.2.14.dist-info/WHEEL +4 -0
- pytractoviz-0.2.14.dist-info/entry_points.txt +5 -0
- pytractoviz-0.2.14.dist-info/licenses/LICENSE +21 -0
pytractoviz/viz.py
ADDED
|
@@ -0,0 +1,4272 @@
|
|
|
1
|
+
"""Visualization module for diffusion tractography."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import gc
|
|
7
|
+
import logging
|
|
8
|
+
import multiprocessing
|
|
9
|
+
import os
|
|
10
|
+
import resource
|
|
11
|
+
import sys
|
|
12
|
+
import tempfile
|
|
13
|
+
import tracemalloc
|
|
14
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import psutil
|
|
20
|
+
except ImportError:
|
|
21
|
+
psutil = None
|
|
22
|
+
|
|
23
|
+
import imageio
|
|
24
|
+
import matplotlib.pyplot as plt
|
|
25
|
+
import nibabel as nib
|
|
26
|
+
import numpy as np
|
|
27
|
+
import vtk
|
|
28
|
+
from dipy.io.stateful_tractogram import Space, StatefulTractogram
|
|
29
|
+
from dipy.io.streamline import load_trk
|
|
30
|
+
from dipy.segment.bundles import bundle_shape_similarity
|
|
31
|
+
from dipy.segment.clustering import QuickBundles
|
|
32
|
+
from dipy.segment.featurespeed import ResampleFeature
|
|
33
|
+
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
|
|
34
|
+
from dipy.stats.analysis import afq_profile, assignment_map, gaussian_weights
|
|
35
|
+
from dipy.tracking.streamline import (
|
|
36
|
+
Streamlines,
|
|
37
|
+
cluster_confidence,
|
|
38
|
+
orient_by_streamline,
|
|
39
|
+
transform_streamlines,
|
|
40
|
+
)
|
|
41
|
+
from dipy.tracking.utils import length
|
|
42
|
+
from fury import actor, window
|
|
43
|
+
from fury.colormap import create_colormap
|
|
44
|
+
from PIL import Image
|
|
45
|
+
from scipy.spatial.transform import Rotation
|
|
46
|
+
from xvfbwrapper import Xvfb
|
|
47
|
+
|
|
48
|
+
from pytractoviz.html import create_quality_check_html
|
|
49
|
+
from pytractoviz.utils import (
|
|
50
|
+
ANATOMICAL_VIEW_ANGLES,
|
|
51
|
+
calculate_bbox_size,
|
|
52
|
+
calculate_centroid,
|
|
53
|
+
calculate_combined_bbox_size,
|
|
54
|
+
calculate_combined_centroid,
|
|
55
|
+
calculate_direction_colors,
|
|
56
|
+
set_anatomical_camera,
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
logger = logging.getLogger(__name__)

# --- Thresholds for memory estimation and warnings ---------------------------
# A single streamline with more points than this adds a fragmentation penalty
# to the actor-memory estimate (VTK may need one large contiguous buffer).
MAX_POINTS_PER_STREAMLINE_FRAGMENTATION_THRESHOLD = 10_000
# Total points across a tract above which the dense-tract warning is raised.
DENSE_TRACT_POINT_THRESHOLD = 1_000_000
# Points in one streamline above which the very-long-streamline warning is raised.
VERY_LONG_STREAMLINE_POINT_THRESHOLD = 50_000
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _log_memory_usage(
    label: str = "",
    *,
    enable_tracemalloc: bool = False,
    log_level: int = logging.DEBUG,
) -> dict[str, float | int] | None:
    """Log current memory usage for debugging memory issues.

    Up to three independent sources are consulted. Each source that works
    contributes to the log output and to the returned dictionary:

    - ``resource.getrusage`` (stdlib, Unix only): the *peak* resident set
      size of this process so far (``ru_maxrss``), not the current RSS.
    - psutil (optional dependency): current process RSS/VMS, the process's
      share of system memory, and the system's available memory.
    - tracemalloc (stdlib): the top 5 allocation sites plus current and peak
      traced Python memory. This source is used only when
      ``enable_tracemalloc`` is True *and* tracing has already been started.

    Parameters
    ----------
    label : str, optional
        Label to identify this memory checkpoint (e.g., "after loading tract").
    enable_tracemalloc : bool, default=False
        If True, also log top memory allocations from tracemalloc.
        Note: tracemalloc must be started with tracemalloc.start() first.
    log_level : int, default=logging.DEBUG
        Logging level to use for memory information.

    Returns
    -------
    dict[str, float | int] | None
        The statistics that could be collected, or None if no source produced
        any. Possible keys (all values in MB except ``percent``):

        - ``resource_rss_mb``: peak RSS from getrusage.
        - ``rss_mb``, ``vms_mb``, ``percent``, ``available_mb``: from psutil.
        - ``tracemalloc_current_mb``, ``tracemalloc_peak_mb``: from tracemalloc.

    Examples
    --------
    >>> # Basic usage
    >>> _log_memory_usage("Before processing")
    >>> # Process data...
    >>> _log_memory_usage("After processing")

    >>> # With tracemalloc for detailed tracking
    >>> import tracemalloc
    >>> tracemalloc.start()
    >>> _log_memory_usage("Checkpoint 1", enable_tracemalloc=True)
    """
    mib = 1024**2
    memory_info: dict[str, float | int] = {}
    label_suffix = f" [{label}]" if label else ""

    # Source 1: resource module (stdlib, Unix only, no extra dependencies).
    try:
        usage = resource.getrusage(resource.RUSAGE_SELF)
        # ru_maxrss is the *peak* RSS. It is reported in bytes on macOS and
        # in kilobytes on Linux, so the divisor to reach MB differs.
        divisor = mib if sys.platform == "darwin" else 1024
        resource_rss_mb = usage.ru_maxrss / divisor
    except (OSError, AttributeError):
        # resource data not available on this platform
        pass
    else:
        memory_info["resource_rss_mb"] = round(resource_rss_mb, 2)
        logger.log(
            log_level,
            "Memory usage%s (resource): peak RSS=%.2f MB",
            label_suffix,
            resource_rss_mb,
        )

    # Source 2: process-level memory via psutil (more detailed, optional).
    if psutil is not None:
        try:
            process = psutil.Process(os.getpid())
            mem_info = process.memory_info()
            mem_percent = process.memory_percent()
        except (psutil.Error, OSError) as exc:
            logger.log(log_level, "Memory usage%s: psutil process query failed: %s", label_suffix, exc)
        else:
            try:
                available_mb = psutil.virtual_memory().available / mib
            except (OSError, RuntimeError, AttributeError):
                available_mb = 0.0

            rss_mb = mem_info.rss / mib  # Resident Set Size
            vms_mb = mem_info.vms / mib  # Virtual Memory Size

            # update() rather than reassignment so the resource-based value
            # collected above is not discarded.
            memory_info.update(
                {
                    "rss_mb": round(rss_mb, 2),
                    "vms_mb": round(vms_mb, 2),
                    "percent": round(mem_percent, 2),
                    "available_mb": round(available_mb, 2),
                },
            )
            logger.log(
                log_level,
                "Memory usage%s: RSS=%.2f MB, VMS=%.2f MB, Process=%.2f%%, System Available=%.2f MB",
                label_suffix,
                rss_mb,
                vms_mb,
                mem_percent,
                available_mb,
            )
    else:
        logger.log(
            log_level,
            "Memory monitoring: psutil not available. Install with: pip install psutil",
        )

    # Source 3: Python-level allocations via tracemalloc (opt-in, must be tracing).
    if enable_tracemalloc and tracemalloc.is_tracing():
        top_stats = tracemalloc.take_snapshot().statistics("lineno")

        logger.log(log_level, "Top 5 memory allocations%s:", label_suffix)
        for index, stat in enumerate(top_stats[:5], 1):
            logger.log(
                log_level,
                "  #%d: %s: %.1f MB",
                index,
                stat.traceback[0],
                stat.size / mib,
            )

        current, peak = tracemalloc.get_traced_memory()
        logger.log(
            log_level,
            "Tracemalloc: Current=%.2f MB, Peak=%.2f MB",
            current / mib,
            peak / mib,
        )
        memory_info["tracemalloc_current_mb"] = round(current / mib, 2)
        memory_info["tracemalloc_peak_mb"] = round(peak / mib, 2)

    return memory_info or None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _set_memory_limit(memory_limit_mb: float | None = None) -> None:
|
|
197
|
+
"""Set a hard memory limit for the current process.
|
|
198
|
+
|
|
199
|
+
This uses resource.setrlimit() to set a maximum virtual memory (address space)
|
|
200
|
+
limit. If the process exceeds this limit, it will be killed by the OS.
|
|
201
|
+
|
|
202
|
+
Parameters
|
|
203
|
+
----------
|
|
204
|
+
memory_limit_mb : float | None, optional
|
|
205
|
+
Maximum memory limit in MB. If None, no limit is set.
|
|
206
|
+
If set, the process will be killed if it exceeds this limit.
|
|
207
|
+
|
|
208
|
+
Examples
|
|
209
|
+
--------
|
|
210
|
+
>>> # Limit to 8 GB
|
|
211
|
+
>>> _set_memory_limit(8192)
|
|
212
|
+
|
|
213
|
+
>>> # Limit to 4 GB
|
|
214
|
+
>>> _set_memory_limit(4096)
|
|
215
|
+
|
|
216
|
+
Notes
|
|
217
|
+
-----
|
|
218
|
+
- This sets RLIMIT_AS (virtual memory/address space limit)
|
|
219
|
+
- On macOS, this is enforced by the kernel
|
|
220
|
+
- The limit applies to the entire process, including all threads
|
|
221
|
+
- Once set, the limit cannot be increased (only decreased)
|
|
222
|
+
"""
|
|
223
|
+
if memory_limit_mb is None:
|
|
224
|
+
return
|
|
225
|
+
|
|
226
|
+
try:
|
|
227
|
+
# Convert MB to bytes
|
|
228
|
+
memory_limit_bytes = int(memory_limit_mb * 1024 * 1024)
|
|
229
|
+
|
|
230
|
+
# Set virtual memory limit (RLIMIT_AS)
|
|
231
|
+
# This limits the total address space the process can use
|
|
232
|
+
resource.setrlimit(resource.RLIMIT_AS, (memory_limit_bytes, resource.RLIM_INFINITY))
|
|
233
|
+
|
|
234
|
+
logger.info("Memory limit set to %.2f MB (%.2f GB)", memory_limit_mb, memory_limit_mb / 1024)
|
|
235
|
+
except (OSError, ValueError) as e:
|
|
236
|
+
logger.warning("Failed to set memory limit: %s", e)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _estimate_actor_memory_mb(streamlines: Streamlines, figure_size: tuple[int, int] = (800, 800)) -> float:
    """Estimate memory needed for creating VTK actors and rendering.

    The estimate is a heuristic made of:

    - ~200 bytes per streamline point for actor data (coordinates, colors,
      normals and VTK-internal structures);
    - ``figure_size[0] * figure_size[1] * 4`` bytes (RGBA), times 3 buffers,
      for rendering;
    - 30% overhead on the two terms above;
    - an extra 50% of the actor term when any single streamline has more than
      ``MAX_POINTS_PER_STREAMLINE_FRAGMENTATION_THRESHOLD`` points. VTK may
      then need a large contiguous buffer, and a warning is logged.

    Parameters
    ----------
    streamlines : Streamlines
        The streamlines to visualize. The container is traversed exactly once,
        so one-shot iterables of per-streamline point arrays also work.
    figure_size : tuple[int, int], default=(800, 800)
        Size of output image in pixels.

    Returns
    -------
    float
        Estimated memory in MB needed for actor creation and rendering.
    """
    # Collect per-streamline point counts in a single pass. Iterating twice
    # would exhaust one-shot iterables, and ``if streamlines`` is ambiguous
    # for numpy arrays.
    point_counts = [len(sl) for sl in streamlines]
    num_streamlines = len(point_counts)
    total_points = sum(point_counts)
    avg_points_per_sl = total_points / num_streamlines if num_streamlines > 0 else 0
    max_points_per_sl = max(point_counts, default=0)

    actor_memory_mb = (total_points * 200) / (1024**2)
    render_memory_mb = (figure_size[0] * figure_size[1] * 4 * 3) / (1024**2)

    # Penalty for very long streamlines (contiguous-allocation / fragmentation risk).
    fragmentation_penalty_mb = 0.0
    if max_points_per_sl > MAX_POINTS_PER_STREAMLINE_FRAGMENTATION_THRESHOLD:
        fragmentation_penalty_mb = actor_memory_mb * 0.5
        logger.warning(
            "Very long streamline detected (%d points). VTK may need large contiguous memory allocation.",
            max_points_per_sl,
        )

    overhead_mb = (actor_memory_mb + render_memory_mb) * 0.3
    total_mb = actor_memory_mb + render_memory_mb + overhead_mb + fragmentation_penalty_mb

    logger.info(
        "Memory estimate for actor: %.2f MB (total_points=%d, streamlines=%d, "
        "avg_points/sl=%.1f, max_points/sl=%d, image=%dx%d)",
        total_mb,
        total_points,
        num_streamlines,
        avg_points_per_sl,
        max_points_per_sl,
        figure_size[0],
        figure_size[1],
    )

    return total_mb
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _check_memory_available(required_mb: float, safety_margin: float = 0.2) -> bool:
    """Report whether the system currently has room for *required_mb* more memory.

    The requirement is inflated by ``safety_margin``. That inflated figure is
    compared with psutil's ``virtual_memory().available``. The check is
    permissive: without psutil, or if querying memory fails, it answers True.

    Parameters
    ----------
    required_mb : float
        Estimated memory required in MB.
    safety_margin : float, default=0.2
        Extra fraction added on top of ``required_mb`` (0.2 means +20%).

    Returns
    -------
    bool
        False only when the available memory is strictly below the inflated
        requirement; True otherwise.
    """
    if psutil is None:
        logger.debug("psutil not available, skipping memory check")
        return True

    try:
        bytes_per_mb = 1024**2
        free_mb = psutil.virtual_memory().available / bytes_per_mb
        own_rss_mb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
        needed_mb = required_mb * (1 + safety_margin)

        if free_mb < needed_mb:
            logger.warning(
                "Insufficient memory: Need %.2f MB (with %.0f%% margin), "
                "but only %.2f MB available. Current process: %.2f MB",
                needed_mb,
                safety_margin * 100,
                free_mb,
                own_rss_mb,
            )
            return False

        logger.debug(
            "Memory check: Need %.2f MB, Available: %.2f MB, Current: %.2f MB",
            needed_mb,
            free_mb,
            own_rss_mb,
        )
    except (OSError, RuntimeError, AttributeError) as e:
        logger.warning("Memory check failed: %s. Proceeding anyway.", e)

    return True
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _get_n_jobs_with_memory_limit(
    base_n_jobs: int,
    estimated_memory_per_job_mb: float = 2000.0,
    safety_margin: float = 0.2,
) -> int:
    """Calculate n_jobs considering available memory.

    Reduces ``base_n_jobs`` when the system's currently available memory
    cannot hold that many workers at ``estimated_memory_per_job_mb`` each
    (inflated by ``safety_margin``). Half of the calling process's current RSS
    is first subtracted from the available memory as a reserve. The result is
    never larger than ``base_n_jobs``. When memory is the limiting factor, the
    result is at least 1.

    Parameters
    ----------
    base_n_jobs : int
        Base number of jobs to use (from CPU/SLURM settings).
    estimated_memory_per_job_mb : float, default=2000.0
        Estimated memory per job in MB. Default assumes ~2GB per worker.
    safety_margin : float, default=0.2
        Safety margin as fraction (0.2 = 20% extra buffer).

    Returns
    -------
    int
        Adjusted number of jobs considering memory constraints. ``base_n_jobs``
        is returned unchanged in three cases: psutil is unavailable, the
        memory query fails, or the per-job estimate (with margin) is not a
        positive number.
    """
    if psutil is None:
        # If psutil not available, return base_n_jobs
        return base_n_jobs

    required_per_job = estimated_memory_per_job_mb * (1 + safety_margin)
    # Guard the division below. A zero estimate would raise ZeroDivisionError,
    # a NaN estimate would make int() raise ValueError, and a negative estimate
    # yields a meaningless job count.
    if not required_per_job > 0:
        logger.warning(
            "Invalid per-job memory estimate (%.2f MB including margin). Using base_n_jobs=%d",
            required_per_job,
            base_n_jobs,
        )
        return base_n_jobs

    try:
        available_mb = psutil.virtual_memory().available / (1024**2)
        current_mb = psutil.Process(os.getpid()).memory_info().rss / (1024**2)
    except (OSError, RuntimeError, AttributeError, psutil.Error) as e:
        logger.warning("Memory-based n_jobs calculation failed: %s. Using base_n_jobs=%d", e, base_n_jobs)
        return base_n_jobs

    # Reserve 50% of the main process's current RSS for the main process itself.
    usable_mb = available_mb - (current_mb * 0.5)
    max_jobs_by_memory = max(1, int(usable_mb / required_per_job))
    optimal_jobs = min(base_n_jobs, max_jobs_by_memory)

    if optimal_jobs < base_n_jobs:
        logger.info(
            "Reduced n_jobs from %d to %d due to memory constraints "
            "(Usable: %.2f MB, Estimated per job: %.2f MB)",
            base_n_jobs,
            optimal_jobs,
            usable_mb,
            required_per_job,
        )

    return optimal_jobs
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def _parse_leading_int(raw_value: str | None) -> int | None:
    """Return the integer formed by the leading ASCII digits of *raw_value*.

    Some environment variables carry more than a bare number, e.g.
    ``SLURM_JOB_CPUS_PER_NODE="72(x2),36"`` or a nested
    ``OMP_NUM_THREADS="4,2"``; only the first count is used. Returns None if
    *raw_value* is None or, after stripping whitespace, does not start with
    a digit.
    """
    if raw_value is None:
        return None
    text = raw_value.strip()
    end = 0
    while end < len(text) and text[end] in "0123456789":
        end += 1
    return int(text[:end]) if end else None


def _get_optimal_n_jobs() -> int:
    """Calculate optimal number of jobs considering SLURM and OpenMP settings.

    This function respects SLURM CPU allocations and OpenMP thread settings
    to prevent oversubscription. It checks, in order:

    1. SLURM_CPUS_PER_TASK (if set and parseable)
    2. SLURM_JOB_CPUS_PER_NODE (if set and parseable). Its format is e.g.
       ``"72(x2),36"``, and the first CPU count is used.
    3. Otherwise, multiprocessing.cpu_count()

    Under SLURM, an OMP_NUM_THREADS greater than 1 divides the allocated CPU
    count. Outside SLURM, any OMP_NUM_THREADS >= 1 divides cpu_count(). An
    unset or unparseable OMP_NUM_THREADS is ignored. An unparseable or
    non-positive SLURM value is logged and skipped rather than raising.

    Returns
    -------
    int
        Optimal number of parallel jobs to use (at least 1 whenever a SLURM
        value or OMP_NUM_THREADS is used).
    """
    omp_threads = _parse_leading_int(os.environ.get("OMP_NUM_THREADS"))

    # SLURM_CPUS_PER_TASK takes precedence over SLURM_JOB_CPUS_PER_NODE.
    for env_name in ("SLURM_CPUS_PER_TASK", "SLURM_JOB_CPUS_PER_NODE"):
        raw_value = os.environ.get(env_name)
        if not raw_value:
            continue
        allocated_cpus = _parse_leading_int(raw_value)
        if allocated_cpus is None or allocated_cpus < 1:
            logger.warning("Ignoring unusable %s=%r when choosing n_jobs", env_name, raw_value)
            continue
        logger.debug("Using %s=%d", env_name, allocated_cpus)

        # Under SLURM, only an explicit OMP_NUM_THREADS > 1 reduces n_jobs.
        if omp_threads is not None and omp_threads > 1:
            optimal_jobs = max(1, allocated_cpus // omp_threads)
            logger.debug(
                "OMP_NUM_THREADS=%d detected. Using n_jobs=%d (allocated_cpus=%d / omp_threads=%d)",
                omp_threads,
                optimal_jobs,
                allocated_cpus,
                omp_threads,
            )
            return optimal_jobs
        return allocated_cpus

    # Not in SLURM: honour OMP_NUM_THREADS if set, else use every CPU.
    total_cpus = multiprocessing.cpu_count()
    if omp_threads is not None and omp_threads > 0:
        optimal_jobs = max(1, total_cpus // omp_threads)
        logger.debug(
            "OMP_NUM_THREADS=%d detected. Using n_jobs=%d (cpu_count=%d / omp_threads=%d)",
            omp_threads,
            optimal_jobs,
            total_cpus,
            omp_threads,
        )
        return optimal_jobs

    logger.debug("Using all available CPUs: n_jobs=%d", total_cpus)
    return total_cpus
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _configure_worker_for_headless_rendering() -> None:
    """Prepare the current (worker) process for rendering without a display.

    The function performs three steps:

    - Sets the VTK-related environment variables below. Each is applied only
      when the variable is not already set.
    - Turns off VTK's global warning display.
    - Switches matplotlib to the non-interactive ``Agg`` backend.

    Failures of the VTK and matplotlib calls are suppressed.
    """
    # NOTE(review): these variables are presumably read by VTK/FURY to force
    # offscreen rendering in headless cluster jobs — confirm they are honoured.
    for env_name, env_value in (
        ("VTK_STREAM_READER", "1"),
        ("VTK_STREAM_WRITER", "1"),
        ("VTK_USE_OFFSCREEN", "1"),
        ("VTK_USE_OSMESA", "0"),
    ):
        os.environ.setdefault(env_name, env_value)

    with contextlib.suppress(AttributeError, RuntimeError, ImportError):
        vtk.vtkObject.GlobalWarningDisplayOff()
        with contextlib.suppress(AttributeError, RuntimeError):
            vtk.vtkOutputWindow.SetGlobalWarningDisplay(0)

    with contextlib.suppress(ImportError, ValueError, RuntimeError):
        plt.switch_backend("Agg")


def _process_tract_worker(
    subject_id: str,
    tract_name: str,
    tract_file: str | Path,
    subject_ref_img: str | Path,
    tract_output_dir: str | Path,
    subjects_mni_space: dict[str, dict[str, str | Path]] | None,
    atlas_files: dict[str, str | Path] | None,
    metric_files: dict[str, dict[str, str | Path]] | None,
    atlas_ref_img: str | Path | None,
    *,
    flip_lr: bool,
    skip_checks: list[str],
    visualizer_params: dict[str, Any],
    **kwargs: Any,
) -> tuple[str, str, dict[str, str | Path]]:
    """Process one tract inside a worker used for parallel processing.

    Each call builds its own ``TractographyVisualizer`` from
    ``visualizer_params`` and delegates to its ``_process_single_tract``.
    Exceptions derived from ``Exception`` are logged and turned into an empty
    results dict, so a failing tract does not raise out of the worker. The
    visualizer reference is dropped and ``gc.collect()`` runs three times
    before returning.

    Returns
    -------
    tuple[str, str, dict[str, str | Path]]
        ``(subject_id, tract_name, results)``. ``results`` is empty on failure.
    """
    visualizer = None
    results: dict[str, str | Path] = {}

    try:
        _configure_worker_for_headless_rendering()

        try:
            visualizer = TractographyVisualizer(**visualizer_params)
        except (OSError, RuntimeError, MemoryError) as init_error:
            logger.exception(
                "Failed to initialize visualizer in worker for %s/%s: %s",
                subject_id,
                tract_name,
                type(init_error).__name__,
            )
            return (subject_id, tract_name, {})

        try:
            results = visualizer._process_single_tract(
                subject_id=subject_id,
                tract_name=tract_name,
                tract_file=tract_file,
                subject_ref_img=Path(subject_ref_img),
                tract_output_dir=Path(tract_output_dir),
                subjects_mni_space=subjects_mni_space,
                atlas_files=atlas_files,
                metric_files=metric_files,
                atlas_ref_img=atlas_ref_img,
                flip_lr=flip_lr,
                skip_checks=skip_checks,
                **kwargs,
            )
        except Exception as tract_error:
            # Pick the log message by error category; the pool must not break.
            if isinstance(tract_error, MemoryError):
                logger.exception(
                    "Memory error in worker process for %s/%s. Consider reducing n_jobs or increasing memory allocation.",
                    subject_id,
                    tract_name,
                )
            elif isinstance(tract_error, (OSError, ValueError, RuntimeError)):
                logger.exception(
                    "Error in worker process for %s/%s (%s)",
                    subject_id,
                    tract_name,
                    type(tract_error).__name__,
                )
            else:
                logger.exception(
                    "Unexpected error in worker process for %s/%s (%s)",
                    subject_id,
                    tract_name,
                    type(tract_error).__name__,
                )
            results = {}
    except Exception as setup_error:
        # Failures outside the tract-processing call itself (setup phase).
        if isinstance(setup_error, (OSError, ValueError, RuntimeError, MemoryError)):
            logger.exception(
                "Error during worker initialization for %s/%s (%s)",
                subject_id,
                tract_name,
                type(setup_error).__name__,
            )
        else:
            logger.exception(
                "Unexpected error during worker setup for %s/%s (%s)",
                subject_id,
                tract_name,
                type(setup_error).__name__,
            )
        results = {}
    finally:
        # Drop the visualizer reference and collect repeatedly so that
        # reference cycles (e.g. among VTK objects) are reclaimed promptly.
        visualizer = None
        for _ in range(3):
            gc.collect()

    return (subject_id, tract_name, results)
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
class TractographyVisualizationError(Exception):
    """Root of this module's tractography-visualization exception hierarchy.

    Catch this type to handle any visualization-specific failure at once.
    """


class InvalidInputError(TractographyVisualizationError):
    """Signals that input data supplied for visualization is invalid."""
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
class TractographyVisualizer:
|
|
653
|
+
"""A class for visualizing diffusion tractography data.
|
|
654
|
+
|
|
655
|
+
This class provides methods for loading, processing, and visualizing
|
|
656
|
+
tractography data, including generating static images, animations, and
|
|
657
|
+
quality metrics.
|
|
658
|
+
|
|
659
|
+
Parameters
|
|
660
|
+
----------
|
|
661
|
+
reference_image : str | Path, optional
|
|
662
|
+
Path to the reference T1-weighted image. Can be set later via
|
|
663
|
+
`set_reference_image()`.
|
|
664
|
+
output_directory : str | Path, optional
|
|
665
|
+
Default output directory for generated files. Can be set later via
|
|
666
|
+
`set_output_directory()`.
|
|
667
|
+
gif_size : tuple[int, int], optional
|
|
668
|
+
Size of generated GIFs in pixels. Default is (608, 608).
|
|
669
|
+
gif_duration : float, optional
|
|
670
|
+
Duration per frame in seconds. Default is 0.2.
|
|
671
|
+
gif_palette_size : int, optional
|
|
672
|
+
Color palette size for GIF optimization. Default is 64.
|
|
673
|
+
gif_frames : int, optional
|
|
674
|
+
Number of frames in rotation animation. Default is 60.
|
|
675
|
+
min_streamline_length : float, optional
|
|
676
|
+
Minimum streamline length for CCI calculation. Default is 40.0.
|
|
677
|
+
cci_threshold : float, optional
|
|
678
|
+
Minimum CCI value to keep streamlines. Default is 1.0.
|
|
679
|
+
afq_resample_points : int, optional
|
|
680
|
+
Number of points for AFQ resampling. Default is 100.
|
|
681
|
+
n_jobs : int, optional
|
|
682
|
+
Number of parallel jobs to run for processing multiple subjects/tracts.
|
|
683
|
+
Default is 1 (sequential processing). Use -1 to automatically determine
|
|
684
|
+
optimal number based on available resources (respects SLURM allocations
|
|
685
|
+
and OpenMP thread settings to prevent oversubscription).
|
|
686
|
+
Only used in `run_quality_check_workflow()`.
|
|
687
|
+
|
|
688
|
+
Note: When running under SLURM, this will automatically use
|
|
689
|
+
SLURM_CPUS_PER_TASK or SLURM_JOB_CPUS_PER_NODE. If OMP_NUM_THREADS is set,
|
|
690
|
+
it will divide the available CPUs by the number of OpenMP threads to
|
|
691
|
+
prevent resource contention.
|
|
692
|
+
max_memory_mb : float | None, optional
|
|
693
|
+
Maximum memory limit in MB for the process. If set, the process will be
|
|
694
|
+
killed by the OS if it exceeds this limit. This helps prevent OOM kills
|
|
695
|
+
by setting a hard limit. Default is None (no limit).
|
|
696
|
+
|
|
697
|
+
Note: This uses resource.setrlimit() which sets a virtual memory limit.
|
|
698
|
+
Once set, the limit cannot be increased (only decreased).
|
|
699
|
+
|
|
700
|
+
Examples
|
|
701
|
+
--------
|
|
702
|
+
Single subject usage:
|
|
703
|
+
>>> visualizer = TractographyVisualizer(
|
|
704
|
+
... reference_image="path/to/t1w.nii.gz", output_directory="output/"
|
|
705
|
+
... )
|
|
706
|
+
>>> visualizer.generate_videos(
|
|
707
|
+
... tract_files=["tract1.trk", "tract2.trk"], ref_file="t1w.nii.gz"
|
|
708
|
+
... )
|
|
709
|
+
|
|
710
|
+
Multiple subjects usage (initialize once, set data per subject):
|
|
711
|
+
>>> visualizer = TractographyVisualizer(output_directory="output/")
|
|
712
|
+
>>> # Process subject 1
|
|
713
|
+
>>> visualizer.generate_videos(
|
|
714
|
+
... tract_files=["subj1_tract1.trk"], ref_file="subj1_t1w.nii.gz"
|
|
715
|
+
... )
|
|
716
|
+
>>> # Process subject 2
|
|
717
|
+
>>> visualizer.generate_videos(
|
|
718
|
+
... tract_files=["subj2_tract1.trk"], ref_file="subj2_t1w.nii.gz"
|
|
719
|
+
... )
|
|
720
|
+
"""
|
|
721
|
+
|
|
722
|
+
def __init__(
    self,
    reference_image: str | Path | None = None,
    output_directory: str | Path | None = None,
    *,
    gif_size: tuple[int, int] = (608, 608),
    gif_duration: float = 0.2,
    gif_palette_size: int = 64,
    gif_frames: int = 60,
    min_streamline_length: float = 40.0,
    cci_threshold: float = 1.0,
    afq_resample_points: int = 100,
    n_jobs: int = 1,
    max_memory_mb: float | None = None,
) -> None:
    """Initialize the TractographyVisualizer.

    Path arguments are routed through ``set_reference_image`` and
    ``set_output_directory`` (so they are validated / created immediately);
    every other keyword is stored as a plain attribute of the same name.
    ``n_jobs=-1`` is resolved with ``_get_optimal_n_jobs()``, and a non-None
    ``max_memory_mb`` is applied with ``_set_memory_limit`` before being
    recorded on the instance.
    """
    # Path-backed state starts unset; the setters below fill it in.
    self._reference_image: Path | None = None
    self._output_directory: Path | None = None
    self.max_memory_mb: float | None = None

    if reference_image is not None:
        self.set_reference_image(reference_image)
    if output_directory is not None:
        self.set_output_directory(output_directory)

    # GIF rendering settings.
    self.gif_size = gif_size
    self.gif_duration = gif_duration
    self.gif_palette_size = gif_palette_size
    self.gif_frames = gif_frames

    # Analysis settings.
    self.min_streamline_length = min_streamline_length
    self.cci_threshold = cci_threshold
    self.afq_resample_points = afq_resample_points

    # -1 is a sentinel for "choose the worker count automatically".
    self.n_jobs = _get_optimal_n_jobs() if n_jobs == -1 else n_jobs

    # Apply the process-wide limit first; the attribute is only updated once
    # the limit has been set successfully.
    if max_memory_mb is not None:
        _set_memory_limit(max_memory_mb)
    self.max_memory_mb = max_memory_mb
|
|
766
|
+
|
|
767
|
+
def set_reference_image(self, reference_image: str | Path) -> None:
    """Set the reference T1-weighted image.

    Parameters
    ----------
    reference_image : str | Path
        Path to the reference image file.

    Raises
    ------
    FileNotFoundError
        If ``reference_image`` does not point to an existing regular file.
        This covers missing paths and directories, including the empty
        string, which would otherwise resolve to the current directory.
    """
    ref_path = Path(reference_image)
    # ``exists()`` alone would accept directories (e.g. ``Path("")`` is the
    # CWD), deferring the failure to a confusing image-loading error later.
    if not ref_path.is_file():
        raise FileNotFoundError(f"Reference image not found: {reference_image}")
    self._reference_image = ref_path
|
|
784
|
+
|
|
785
|
+
def set_output_directory(self, output_directory: str | Path) -> None:
    """Point generated output at ``output_directory``, creating it if needed.

    Parameters
    ----------
    output_directory : str | Path
        Destination directory. Missing parent directories are created too;
        an already-existing directory is accepted as-is.

    Raises
    ------
    OSError
        If the directory cannot be created (for example, a regular file
        already occupies the path).
    """
    target = Path(output_directory)
    # exist_ok keeps re-pointing at an existing directory a no-op.
    target.mkdir(exist_ok=True, parents=True)
    self._output_directory = target
|
|
801
|
+
|
|
802
|
+
@property
def reference_image(self) -> Path | None:
    """Path of the configured reference image, or None when none is set."""
    current: Path | None = self._reference_image
    return current
|
|
806
|
+
|
|
807
|
+
@property
def output_directory(self) -> Path | None:
    """Directory generated files are written to, or None when none is set."""
    current: Path | None = self._output_directory
    return current
|
|
811
|
+
|
|
812
|
+
def get_glass_brain(self, t1w_img: str | Path | None = None) -> actor.Actor:
    """Build a translucent "glass brain" contour actor from a T1-weighted image.

    Parameters
    ----------
    t1w_img : str | Path | None, optional
        Path to the T1-weighted image. If None, the reference image set in the
        constructor or via `set_reference_image()` is used.

    Returns
    -------
    actor.Actor
        A black contour actor (opacity 0.05) produced by
        ``actor.contour_from_roi`` over the full image data.

    Raises
    ------
    InvalidInputError
        If no image is passed and no reference image has been set.
    TractographyVisualizationError
        If the image cannot be loaded or read. ``OSError`` (including
        ``FileNotFoundError`` for a missing file), ``ValueError`` and
        ``RuntimeError`` raised while loading are wrapped in this error.
    """
    source = t1w_img
    if source is None:
        source = self._reference_image
        if source is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )

    try:
        image = nib.load(str(source))
        voxels = image.get_fdata()  # type: ignore[attr-defined]
    except (OSError, ValueError, RuntimeError) as exc:
        raise TractographyVisualizationError(
            f"Failed to load glass brain from {source}: {exc}",
        ) from exc

    # Only the loading step is guarded; contour construction errors propagate.
    return actor.contour_from_roi(voxels, color=[0, 0, 0], opacity=0.05)
|
|
850
|
+
|
|
851
|
+
def _set_anatomical_camera(
    self,
    scene: window.Scene,
    centroid: np.ndarray,
    view_name: str,
    *,
    camera_distance: float | None = None,
    bbox_size: np.ndarray | None = None,
) -> None:
    """Aim the scene camera at ``centroid`` from a named anatomical view.

    Method wrapper around ``utils.set_anatomical_camera`` that translates its
    ``ValueError`` into this package's ``InvalidInputError``.

    Parameters
    ----------
    scene : window.Scene
        The FURY scene whose camera is positioned.
    centroid : np.ndarray
        Centroid of the streamlines (3D coordinates) to look at.
    view_name : str
        One of "coronal", "axial", or "sagittal".
    camera_distance : float | None, optional
        Camera-to-centroid distance. If None, it is derived from ``bbox_size``.
    bbox_size : np.ndarray | None, optional
        Bounding-box size of the streamlines, used when ``camera_distance``
        is not given.

    Raises
    ------
    InvalidInputError
        If ``view_name`` is not a standard anatomical view.
    """
    placement = {"camera_distance": camera_distance, "bbox_size": bbox_size}
    try:
        set_anatomical_camera(scene, centroid, view_name, **placement)
    except ValueError as err:
        # Surface bad view names as the package's own input-error type.
        raise InvalidInputError(str(err)) from err
|
|
892
|
+
|
|
893
|
+
def _create_scene(
    self,
    *,
    ref_img: str | Path | None = None,
    show_glass_brain: bool = True,
) -> tuple[window.Scene, actor.Actor | None]:
    """Create a white-background scene, optionally containing a glass brain.

    Parameters
    ----------
    ref_img : str | Path | None, optional
        Image used for the glass brain. If None, ``get_glass_brain`` falls
        back to the instance's reference image.
    show_glass_brain : bool, optional
        Whether to add the glass-brain outline. Default is True.

    Returns
    -------
    tuple[window.Scene, actor.Actor | None]
        The scene, and the glass-brain actor if one was added (else None).

    Raises
    ------
    InvalidInputError, TractographyVisualizationError
        Propagated from ``get_glass_brain`` when ``show_glass_brain`` is True.
    """
    scene = window.Scene()
    scene.SetBackground(1, 1, 1)  # RGB white

    if not show_glass_brain:
        return scene, None

    glass_brain = self.get_glass_brain(ref_img)
    scene.add(glass_brain)
    return scene, glass_brain
|
|
922
|
+
|
|
923
|
+
def _create_streamline_actor(
    self,
    streamlines: Streamlines,
    colors: np.ndarray | None = None,
) -> actor.Actor:
    """Create a line actor for ``streamlines``, fitting ``colors`` to their count.

    Parameters
    ----------
    streamlines : Streamlines
        The streamlines to visualize.
    colors : np.ndarray | None, optional
        Colors indexed along the first axis, one entry per streamline.
        If there are fewer entries than streamlines, they are repeated
        cyclically. If there are more, the extras are dropped. None or an
        empty array falls back to ``actor.line``'s default coloring.

    Returns
    -------
    actor.Actor
        The streamline actor.
    """
    if colors is None or len(colors) == 0:
        # An empty color array previously crashed with ZeroDivisionError when
        # computing the repeat factor; use default coloring instead.
        return actor.line(streamlines)

    n_streamlines = len(streamlines)
    if len(colors) == n_streamlines:
        return actor.line(streamlines, colors=colors)

    if len(colors) > n_streamlines:
        return actor.line(streamlines, colors=colors[:n_streamlines])

    # Too few colors: cycle whole entries. np.resize repeats the flattened
    # data, which keeps rows intact for (k, C) input and also handles 1-D
    # input (np.tile(..., (r, 1)) would turn a 1-D array into an (r, k) matrix).
    colors_arr = np.asarray(colors)
    extended_colors = np.resize(colors_arr, (n_streamlines, *colors_arr.shape[1:]))
    return actor.line(streamlines, colors=extended_colors)
|
|
952
|
+
|
|
953
|
+
def _extract_tract_name(self, tract_file: str | Path) -> str:
    """Return the file name of ``tract_file`` with its final suffix removed.

    Only the last extension is stripped, e.g. ``"AF_L.trk"`` -> ``"AF_L"``
    but ``"AF_L.trk.gz"`` -> ``"AF_L.trk"``.

    Parameters
    ----------
    tract_file : str | Path
        Path to a tractography file.

    Returns
    -------
    str
        The tract name used to build output file names.
    """
    tract_path = Path(tract_file)
    name_without_suffix = tract_path.stem
    return name_without_suffix
|
|
967
|
+
|
|
968
|
+
def load_tract(
    self,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
) -> actor.Actor:
    """Load a tractography file and return it as a line actor.

    The tractogram is converted to RASMM, then mapped through the inverse of
    the reference image's affine (i.e. into that image's voxel coordinates)
    before the actor is built.

    Parameters
    ----------
    tract_file : str | Path
        Path to the tractography file (.trk).
    ref_img : str | Path | None, optional
        Path to the reference image. If None, the reference image set in the
        constructor or via `set_reference_image()` is used.

    Returns
    -------
    actor.Actor
        The tractography actor.

    Raises
    ------
    InvalidInputError
        If no reference image is passed and none has been set.
    TractographyVisualizationError
        If loading or transformation fails. ``OSError`` (including
        ``FileNotFoundError`` for missing tract or reference files),
        ``ValueError`` and ``RuntimeError`` are wrapped in this error.
    """
    reference = ref_img
    if reference is None:
        reference = self._reference_image
        if reference is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )

    try:
        sft = load_trk(str(tract_file), "same", bbox_valid_check=False)
        sft.to_rasmm()
        reference_obj = nib.load(str(reference))
        # World (RASMM) -> reference voxel space.
        world_to_voxel = np.linalg.inv(reference_obj.affine)  # type: ignore[attr-defined]
        voxel_streamlines = transform_streamlines(sft.streamlines, world_to_voxel)
    except (OSError, ValueError, RuntimeError) as exc:
        raise TractographyVisualizationError(
            f"Failed to load tract from {tract_file}: {exc}",
        ) from exc

    return actor.line(voxel_streamlines)
|
|
1019
|
+
|
|
1020
|
+
def weighted_afq(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    metric_file: str | Path,
) -> np.ndarray:
    """Calculate a Gaussian-weighted AFQ profile of a metric along a tract.

    The atlas tract is clustered with QuickBundles using an infinite
    threshold, so all atlas streamlines form a single cluster. Each streamline
    is resampled to ``self.afq_resample_points`` points for the distance
    metric. The cluster's centroid is used to orient the subject streamlines
    before the metric image is sampled with ``afq_profile``.

    Parameters
    ----------
    tract_file : str | Path
        Path to the subject tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    metric_file : str | Path
        Path to the metric image file.

    Returns
    -------
    np.ndarray
        The AFQ profile array returned by ``afq_profile``.

    Raises
    ------
    InvalidInputError
        If the subject tractogram is empty, or clustering the atlas yields no
        centroids.
    TractographyVisualizationError
        If loading or processing fails. ``OSError`` (including
        ``FileNotFoundError``), ``ValueError``, ``RuntimeError`` and
        ``IndexError`` are wrapped in this error.
    """
    try:
        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        metric_img = nib.load(str(metric_file))
        metric = metric_img.get_fdata()  # type: ignore[attr-defined]

        # Reject an empty subject bundle up front (same message style as
        # calc_cci) instead of running orientation/weighting on no data.
        if len(tract.streamlines) == 0:
            raise InvalidInputError(f"Tractogram is empty: {tract_file}")

        # threshold=np.inf places every atlas streamline in one cluster, so
        # its centroid serves as the orientation template.
        feature = ResampleFeature(nb_points=self.afq_resample_points)
        qb_metric = AveragePointwiseEuclideanMetric(feature)
        qb = QuickBundles(threshold=np.inf, metric=qb_metric)
        cluster_tract = qb.cluster(atlas_tract.streamlines)

        if len(cluster_tract.centroids) == 0:
            raise InvalidInputError(
                "No centroids found in atlas tractography clustering.",
            )

        standard_tract = cluster_tract.centroids[0]
        oriented_tract = orient_by_streamline(tract.streamlines, standard_tract)
        w_tract = gaussian_weights(oriented_tract)
        profile = afq_profile(
            metric,
            oriented_tract,
            affine=metric_img.affine,  # type: ignore[attr-defined]
            weights=w_tract,
        )

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate weighted AFQ profile: {e}",
        ) from e
    return profile
|
|
1085
|
+
|
|
1086
|
+
def calc_cci(
    self,
    tract: StatefulTractogram,
    ref_img: str | Path | None = None,
) -> tuple[np.ndarray, StatefulTractogram]:
    """Compute the Cluster Confidence Index (CCI) and keep well-supported streamlines.

    Streamlines whose length is not greater than ``self.min_streamline_length``
    are dropped first. CCI is computed on the remaining streamlines with
    ``cluster_confidence``. Only streamlines with ``cci >= self.cci_threshold``
    are kept.

    Parameters
    ----------
    tract : StatefulTractogram
        The tractography data.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, the reference image set in the
        constructor or via `set_reference_image()` is used.

    Returns
    -------
    tuple[np.ndarray, StatefulTractogram]
        - CCI values of the kept streamlines, in tractogram order.
        - A new tractogram of the kept streamlines built against ``ref_img``.

    Raises
    ------
    InvalidInputError
        If the tractogram is empty, no streamline exceeds the minimum length,
        or no reference image is available.
    TractographyVisualizationError
        If calculation fails. ``OSError``, ``ValueError``, ``RuntimeError``
        and ``IndexError`` are wrapped in this error.
    """
    source_streamlines = tract.streamlines
    if not source_streamlines or len(source_streamlines) == 0:
        raise InvalidInputError("Tractogram is empty.")

    reference = ref_img
    if reference is None:
        reference = self._reference_image
        if reference is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )

    try:
        min_len = self.min_streamline_length
        sl_lengths = list(length(source_streamlines))
        long_streamlines = Streamlines(
            [sl for sl, sl_len in zip(source_streamlines, sl_lengths) if sl_len > min_len],
        )

        if len(long_streamlines) == 0:
            raise InvalidInputError(
                f"No streamlines longer than {self.min_streamline_length}mm found.",
            )

        cci = cluster_confidence(long_streamlines)
        keep_mask = cci >= self.cci_threshold

        # Streamlines and CCI values are filtered with the same mask so they
        # stay aligned index-for-index.
        keep_streamlines = Streamlines(
            [sl for sl, keep in zip(long_streamlines, keep_mask) if keep],
        )
        keep_cci = cci[keep_mask]
        # NOTE(review): the result is always tagged RASMM regardless of
        # ``tract.space``; this assumes callers pass RASMM tractograms — confirm.
        keep_tract = StatefulTractogram(keep_streamlines, nib.load(str(reference)), Space.RASMM)

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate CCI: {e}",
        ) from e
    return keep_cci, keep_tract
|
|
1162
|
+
|
|
1163
|
+
def generate_anatomical_views(
    self,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
) -> dict[str, Path]:
    """Generate snapshots from standard anatomical views.

    Writes ``<tract_name>_<view>.png`` images of the tract into the output
    directory. Views whose image already exists are not re-rendered. When
    every requested image already exists, the tract is not loaded at all.

    Parameters
    ----------
    tract_file : str | Path
        Path to the tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization or via `set_reference_image()`.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization. Created if missing.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.

    Returns
    -------
    dict[str, Path]
        Mapping from view name to image path for every requested view whose
        image exists after the call. A view is omitted if it was skipped
        because the memory check failed, or because VTK reported an
        allocation failure while building the actor.

    Raises
    ------
    InvalidInputError
        If an unknown view name is requested, or no output directory or
        reference image is available.
    OSError
        If an explicitly passed ``output_dir`` cannot be created.
    TractographyVisualizationError
        If loading or rendering fails. ``OSError`` (including
        ``FileNotFoundError`` for missing inputs), ``ValueError`` and
        ``RuntimeError`` are wrapped in this error.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> # Generate all views
    >>> views = visualizer.generate_anatomical_views("tract.trk")
    >>> # Generate specific views
    >>> views = visualizer.generate_anatomical_views(
    ...     "tract.trk", views=["coronal", "axial"]
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Resolve output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve reference image
    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    # Check existing outputs BEFORE loading the tract so a fully cached call
    # never pays the memory cost of loading it.
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"{tract_name}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    try:
        # Load the tract and map it into the reference image's voxel space.
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )

        # Calculate colors based on streamline directions using utility function
        streamline_colors = calculate_direction_colors(tract_streamlines)

        # Get centroid using utility function
        centroid = calculate_centroid(tract_streamlines)

        # Log tract statistics for debugging
        n_streamlines = len(tract_streamlines)
        total_points = sum(len(sl) for sl in tract_streamlines)
        avg_points = total_points / n_streamlines if tract_streamlines else 0
        max_points = max(len(sl) for sl in tract_streamlines) if tract_streamlines else 0
        logger.info(
            "Tract statistics: %d streamlines, %d total points, "
            "%.1f avg points/streamline, %d max points/streamline",
            n_streamlines,
            total_points,
            avg_points,
            max_points,
        )

        # Warn if tract is very dense (high risk of std::bad_alloc)
        if total_points > DENSE_TRACT_POINT_THRESHOLD:
            logger.warning(
                "Very dense tract detected (%d total points). "
                "This may cause std::bad_alloc errors due to memory fragmentation. "
                "Consider: (1) resampling streamlines to reduce points, "
                "(2) filtering with higher cci_threshold, (3) reducing image resolution.",
                total_points,
            )
        if max_points > VERY_LONG_STREAMLINE_POINT_THRESHOLD:
            logger.warning(
                "Very long streamline detected (%d points). "
                "VTK may need large contiguous memory allocation which can fail due to fragmentation.",
                max_points,
            )

        # Clean up loaded objects that are no longer needed
        del tract, ref_img_obj

        # These depend only on the streamlines and figure size, so compute
        # them once rather than once per view.
        bbox_size = calculate_bbox_size(tract_streamlines)
        estimated_memory = _estimate_actor_memory_mb(tract_streamlines, figure_size)

        # Generate each requested view
        for view_name in views_to_generate:
            output_image = output_dir / f"{tract_name}_{view_name}.png"

            # Skip if file already exists (recorded during the pre-check)
            if output_image.exists():
                continue

            # Re-check available memory per view; it can change between views.
            if not _check_memory_available(estimated_memory, safety_margin=0.3):
                logger.warning(
                    "Insufficient memory to create actor for view %s. Skipping.",
                    view_name,
                )
                continue

            scene, _ = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)

            try:
                tract_actor = self._create_streamline_actor(tract_streamlines, streamline_colors)
                scene.add(tract_actor)
            except RuntimeError as e:
                # Catch std::bad_alloc and other VTK allocation errors
                error_msg = str(e).lower()
                if "bad_alloc" in error_msg or "memory" in error_msg or "allocation" in error_msg:
                    logger.exception(
                        "VTK memory allocation failed for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution or filtering streamlines.",
                        view_name,
                        n_streamlines,
                        total_points,
                    )
                    scene.clear()
                    continue
                raise

            # Always release the scene's actors, even if camera setup or
            # recording fails, before the error propagates.
            try:
                self._set_anatomical_camera(
                    scene,
                    centroid,
                    view_name,
                    bbox_size=bbox_size,
                )
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            finally:
                scene.clear()
            del tract_actor

            generated_views[view_name] = output_image

        # Clean up remaining large objects
        del tract_streamlines, streamline_colors
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate anatomical views: {e}",
        ) from e
    return generated_views
|
|
1400
|
+
|
|
1401
|
+
def generate_atlas_views(
    self,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    ref_img: str | Path | None = None,  # Alias for atlas_ref_img
    flip_lr: bool = False,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    atlas_name: str | None = None,
) -> dict[str, Path]:
    """Generate anatomical views for an atlas tract.

    Creates static images from coronal, axial, and sagittal views of the atlas
    tract. This is useful for comparing subject tracts to atlas tracts using
    the same viewing angles.

    Parameters
    ----------
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to the reference image that matches the atlas coordinate space
        (e.g., MNI template if atlas is in MNI space).
        This is important if the atlas is in a different space (e.g., MNI)
        than the subject reference image. If None (and ``ref_img`` is also
        None), the reference image set during initialization is used.
    ref_img : str | Path | None, optional
        Alias for ``atlas_ref_img``; only used when ``atlas_ref_img`` is None.
    flip_lr : bool, optional
        Whether to mirror the atlas left-right. The mirror is applied in
        RAS+ world (mm) coordinates by negating X (i.e. about the x = 0 mm
        plane) *before* mapping into the reference image's voxel grid, so the
        flipped tract stays aligned with the glass brain. Try this if the
        atlas appears on the wrong side compared to the subject.
        Default is False.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    atlas_name : str | None, optional
        Name prefix for output files. If None, uses the stem of atlas_file.
        Default is None.

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal". Views whose file already
        existed are included without re-rendering; a view whose rendering
        failed with a VTK memory-allocation error is logged and omitted.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, or a view
        name is invalid.
    TractographyVisualizationError
        If image generation fails (including unreadable input files).

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> # Generate all atlas views
    >>> atlas_views = visualizer.generate_atlas_views("atlas_tract.trk")
    >>> # Generate specific views
    >>> atlas_views = visualizer.generate_atlas_views(
    ...     "atlas_tract.trk", views=["coronal", "axial"]
    ... )
    """
    # Resolve the reference image that defines the atlas voxel grid.
    # ``ref_img`` is a backward-compatible alias for ``atlas_ref_img``.
    if atlas_ref_img is None:
        atlas_ref_img = ref_img
    if atlas_ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        atlas_ref_img = self._reference_image
    else:
        atlas_ref_img = Path(atlas_ref_img)

    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Use provided atlas_name or derive from file
    atlas_name = self._extract_tract_name(atlas_file) if atlas_name is None else str(atlas_name)

    def _is_memory_error(exc: RuntimeError) -> bool:
        """Return True when a VTK RuntimeError looks like an allocation failure."""
        message = str(exc).lower()
        return any(marker in message for marker in ("bad_alloc", "memory", "allocation"))

    generated_views: dict[str, Path] = {}

    # Check which outputs already exist BEFORE loading the atlas tract so
    # nothing heavy is loaded when every image is already on disk.
    output_images = {
        view_name: output_dir / f"{atlas_name}_atlas_{view_name}.png"
        for view_name in views_to_generate
    }
    pending_views: list[str] = []
    for view_name, output_image in output_images.items():
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            pending_views.append(view_name)

    if not pending_views:
        return generated_views

    try:
        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()
        atlas_ref_img_obj = nib.load(str(atlas_ref_img))

        # World (RAS+ mm) -> voxel indices of the atlas reference image; the
        # glass brain is drawn in that voxel grid.
        rasmm_to_voxel = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
        if flip_lr:
            # Mirror left-right in world space (x -> -x) and only then go to
            # voxel space. Negating X *after* the voxel transform would mirror
            # about voxel index 0 and push the tract out of the volume.
            rasmm_to_voxel = rasmm_to_voxel @ np.diag([-1.0, 1.0, 1.0, 1.0])
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, rasmm_to_voxel)
        del atlas_tract

        # Direction colors, centroid and bounding box are the same for every view.
        streamline_colors = calculate_direction_colors(atlas_streamlines)
        centroid = calculate_centroid(atlas_streamlines)
        bbox_size = calculate_bbox_size(atlas_streamlines)

        for view_name in pending_views:
            output_image = output_images[view_name]

            scene, _ = self._create_scene(ref_img=atlas_ref_img, show_glass_brain=show_glass_brain)
            tract_actor = self._create_streamline_actor(atlas_streamlines, streamline_colors)
            try:
                scene.add(tract_actor)
                self._set_anatomical_camera(
                    scene,
                    centroid,
                    view_name,
                    bbox_size=bbox_size,
                )

                # Rendering can also fail with std::bad_alloc
                try:
                    window.record(
                        scene=scene,
                        out_path=str(output_image),
                        size=figure_size,
                    )
                except RuntimeError as e:
                    if not _is_memory_error(e):
                        raise
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(atlas_streamlines),
                        sum(len(sl) for sl in atlas_streamlines),
                    )
                    continue  # skip this view; cleanup happens in ``finally``
                generated_views[view_name] = output_image
            finally:
                # Release VTK objects between views to keep peak memory low.
                scene.clear()
                del scene, tract_actor

        # Clean up large objects
        del atlas_ref_img_obj, atlas_streamlines, streamline_colors
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate atlas views: {e}",
        ) from e
    else:
        return generated_views
|
|
1629
|
+
|
|
1630
|
+
def plot_cci(
    self,
    cci: np.ndarray,
    keep_tract: StatefulTractogram,
    hist_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    bins: int = 100,
) -> dict[str, Path]:
    """Plot CCI with anatomical views.

    Generates anatomical views (coronal, axial, sagittal) with streamlines
    colored by CCI values, along with a histogram plot.

    Parameters
    ----------
    cci : np.ndarray
        CCI values array (one per streamline in keep_tract). Any array-like
        accepted by ``np.asarray`` is also allowed.
    keep_tract : StatefulTractogram
        Filtered tractogram to visualize (must match CCI array length).
    hist_file : str | Path
        Path to save the histogram plot.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    bins : int, optional
        Number of bins for histogram. Default is 100.

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal", "histogram". Existing files
        are reused without re-rendering. A view skipped by the pre-flight
        memory check, or whose rendering failed with a VTK
        memory-allocation error, is logged and omitted.

    Raises
    ------
    InvalidInputError
        If the CCI array is empty, its length does not match the number of
        streamlines in ``keep_tract``, no output directory / reference image
        is available, or a view name is invalid.
    TractographyVisualizationError
        If plotting fails.
    """
    if len(cci) == 0:
        raise InvalidInputError("CCI array is empty.")
    cci = np.asarray(cci)

    # Validate the CCI/tract pairing before doing any work, so a mismatched
    # pair is rejected even when some outputs already exist and the histogram
    # is never written for the wrong data.
    n_streamlines = len(keep_tract.streamlines)
    if len(cci) != n_streamlines:
        raise InvalidInputError(
            f"CCI array length ({len(cci)}) does not match "
            f"tractogram streamlines ({n_streamlines}). "
            f"This can cause memory issues.",
        )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    def _is_memory_error(exc: RuntimeError) -> bool:
        """Return True when a VTK RuntimeError looks like an allocation failure."""
        message = str(exc).lower()
        return any(marker in message for marker in ("bad_alloc", "memory", "allocation"))

    generated_views: dict[str, Path] = {}

    try:
        # Histogram (skipped if the file already exists)
        hist_path = Path(hist_file)
        hist_path.parent.mkdir(parents=True, exist_ok=True)
        if not hist_path.exists():
            fig, ax = plt.subplots(1, figsize=(8, 6))
            ax.hist(cci, bins=bins, histtype="step")
            ax.set_xlabel("CCI")
            ax.set_ylabel("# streamlines")
            ax.set_title("CCI Distribution")
            ax.grid(visible=True, alpha=0.3)
            fig.savefig(str(hist_path), dpi=150, bbox_inches="tight")
            plt.close(fig)
        else:
            logger.debug("Skipping generation of %s (file already exists)", hist_path)
        generated_views["histogram"] = hist_path

        # Check which view images already exist BEFORE processing the tract.
        output_images = {view_name: output_dir / f"cci_{view_name}.png" for view_name in views_to_generate}
        pending_views: list[str] = []
        for view_name, output_image in output_images.items():
            if output_image.exists():
                generated_views[view_name] = output_image
                logger.debug("Skipping generation of %s (file already exists)", output_image)
            else:
                pending_views.append(view_name)

        if not pending_views:
            return generated_views

        # Transform tract streamlines to the reference voxel space
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            keep_tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )
        del ref_img_obj

        # Same color scheme as the original CCI visualization
        cci_min = float(cci.min())
        cci_max = float(cci.max())
        hue = [0.5, 1]
        saturation = [0.0, 1.0]
        lut_cmap = actor.colormap_lookup_table(
            scale_range=(cci_min, cci_max / 4),
            hue_range=hue,
            saturation_range=saturation,
        )
        bar = actor.scalar_bar(lookup_table=lut_cmap)

        # View-independent quantities, computed once
        centroid = calculate_centroid(tract_streamlines)
        bbox_size = calculate_bbox_size(tract_streamlines)
        estimated_memory = _estimate_actor_memory_mb(tract_streamlines, figure_size)
        gc.collect()

        for view_name in pending_views:
            output_image = output_images[view_name]

            # Check memory before allocating any VTK objects for this view,
            # so a skipped view leaves nothing behind.
            if not _check_memory_available(estimated_memory, safety_margin=0.3):
                logger.warning(
                    "Insufficient memory to create actor for %d streamlines. Skipping view %s",
                    len(tract_streamlines),
                    view_name,
                )
                continue

            scene = window.Scene()
            scene.SetBackground(0.5, 0.5, 0.5)
            scene.add(bar)

            # Glass brain (no rotation needed - camera handles view)
            brain_actor = None
            if show_glass_brain:
                brain_actor = self.get_glass_brain(ref_img)
                scene.add(brain_actor)

            tract_actor = None
            try:
                # CCI values drive the lookup table, one value per streamline
                try:
                    tract_actor = actor.line(
                        tract_streamlines,
                        colors=cci,
                        linewidth=0.1,
                        lookup_colormap=lut_cmap,
                    )
                except RuntimeError as e:
                    if not _is_memory_error(e):
                        raise
                    logger.exception(
                        "VTK memory allocation failed for view %s (likely std::bad_alloc). "
                        "Try reducing tract size, image resolution, or using n_jobs=1.",
                        view_name,
                    )
                    continue  # skip this view; cleanup happens in ``finally``
                scene.add(tract_actor)

                self._set_anatomical_camera(
                    scene,
                    centroid,
                    view_name,
                    bbox_size=bbox_size,
                )

                # Rendering can also fail with std::bad_alloc
                try:
                    window.record(
                        scene=scene,
                        out_path=str(output_image),
                        size=figure_size,
                    )
                except RuntimeError as e:
                    if not _is_memory_error(e):
                        raise
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(tract_streamlines),
                        sum(len(sl) for sl in tract_streamlines),
                    )
                    continue  # skip this view; cleanup happens in ``finally``
                generated_views[view_name] = output_image
            finally:
                # VTK/FURY objects can hold circular references, so release them
                # explicitly and collect between views to keep peak memory low.
                scene.clear()
                del scene, tract_actor, brain_actor
                gc.collect()

        # Clean up large objects after all views are generated
        del tract_streamlines
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(f"Failed to plot CCI: {e}") from e
    else:
        return generated_views
|
|
1913
|
+
|
|
1914
|
+
def plot_afq(
    self,
    metric_file: str | Path,
    metric_str: str,
    tract_file: str | Path,
    atlas_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    colormap: str = "Spectral",
) -> dict[str, Path]:
    """Plot AFQ profile with anatomical views.

    Generates anatomical views (coronal, axial, sagittal) with streamlines
    colored by AFQ profile values, plus a line plot of the profile.

    Parameters
    ----------
    metric_file : str | Path
        Path to the metric image file.
    metric_str : str
        Name of the metric (e.g., "FA", "MD") for labeling.
    tract_file : str | Path
        Path to the tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    colormap : str, optional
        Name of the colormap to use for AFQ profile values.
        Default is "Spectral".

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Also includes the "profile_plot" key for the profile line plot,
        whether it was freshly generated or already existed. A view whose
        rendering failed with a VTK memory-allocation error is logged and
        omitted.

    Raises
    ------
    FileNotFoundError
        If required files are not found.
    InvalidInputError
        If no output directory / reference image is available or a view
        name is invalid.
    TractographyVisualizationError
        If image generation fails.
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    profile_plot_path = output_dir / f"{tract_name}_{metric_str}_profile.png"
    profile_plot_exists = profile_plot_path.exists()
    if profile_plot_exists:
        logger.debug("Skipping generation of %s (file already exists)", profile_plot_path)
        generated_views["profile_plot"] = profile_plot_path

    # Check which view images already exist BEFORE doing any heavy work.
    output_images = {
        view_name: output_dir / f"{tract_name}_{metric_str}_{view_name}.png"
        for view_name in views_to_generate
    }
    pending_views: list[str] = []
    for view_name, output_image in output_images.items():
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            pending_views.append(view_name)

    # Everything already on disk: return without computing the profile or
    # loading the tract.
    if not pending_views and profile_plot_exists:
        return generated_views

    # The profile needs the tract, atlas and metric files, so it is only
    # computed once we know at least one output is missing. Called outside
    # the try below so its own exceptions propagate unchanged.
    profile = self.weighted_afq(tract_file, atlas_file, metric_file)

    def _is_memory_error(exc: RuntimeError) -> bool:
        """Return True when a VTK RuntimeError looks like an allocation failure."""
        message = str(exc).lower()
        return any(marker in message for marker in ("bad_alloc", "memory", "allocation"))

    try:
        # Profile line plot: only depends on the profile, not on the tract.
        if not profile_plot_exists:
            fig, ax = plt.subplots(1, figsize=(8, 6))
            ax.plot(profile)
            ax.set_xlabel("Node along tract")
            ax.set_ylabel(metric_str)
            ax.set_title(f"AFQ Profile: {metric_str}")
            ax.grid(visible=True, alpha=0.3)
            fig.savefig(str(profile_plot_path), dpi=150, bbox_inches="tight")
            plt.close(fig)
            generated_views["profile_plot"] = profile_plot_path

        if not pending_views:
            return generated_views

        # Load the tract only when at least one view must be rendered
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )
        del tract, ref_img_obj

        # Per-point colors: resample the profile onto each streamline's
        # points, then map through the colormap (one call per streamline, so
        # each streamline is normalized over its own interpolated values).
        profile_positions = np.linspace(0, 1, len(profile))
        point_colors = np.vstack(
            [
                create_colormap(
                    np.interp(np.linspace(0, 1, len(sl)), profile_positions, profile),
                    name=colormap,
                )
                for sl in tract_streamlines
            ],
        )

        # View-independent quantities, computed once
        centroid = calculate_centroid(tract_streamlines)
        bbox_size = calculate_bbox_size(tract_streamlines)

        for view_name in pending_views:
            output_image = output_images[view_name]

            scene, _ = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)
            # A single actor with one color row per point of every streamline,
            # instead of one VTK actor per streamline.
            tract_actor = actor.line(tract_streamlines, colors=point_colors)
            try:
                scene.add(tract_actor)
                self._set_anatomical_camera(
                    scene,
                    centroid,
                    view_name,
                    bbox_size=bbox_size,
                )

                # Rendering can also fail with std::bad_alloc
                try:
                    window.record(
                        scene=scene,
                        out_path=str(output_image),
                        size=figure_size,
                    )
                except RuntimeError as e:
                    if not _is_memory_error(e):
                        raise
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Tract has %d streamlines with %d total points. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(tract_streamlines),
                        sum(len(sl) for sl in tract_streamlines),
                    )
                    continue  # skip this view; cleanup happens in ``finally``
                generated_views[view_name] = output_image
            finally:
                # Release VTK objects between views to keep peak memory low.
                scene.clear()
                del scene, tract_actor

        del tract_streamlines, point_colors
        gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to plot AFQ profile: {e}",
        ) from e
    else:
        return generated_views
|
|
2144
|
+
|
|
2145
|
+
def calculate_shape_similarity(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    flip_lr: bool = False,
    clust_thr: tuple[float, float, float] = (5, 3, 1.5),
    threshold: float = 6,
    rng: np.random.Generator | None = None,
) -> float:
    """Calculate shape similarity score between tract and atlas.

    Uses DIPY's bundle_shape_similarity function with the Bundle Adjacency (BA) metric
    to compute how closely the shapes of two bundles align.

    Both tracts are loaded and converted to RAS+mm. When ``atlas_ref_img`` is
    given, the inverse of its affine is applied to *both* bundles, so they are
    always compared in one common coordinate frame. This is the same mapping
    ``visualize_shape_similarity`` uses, so the score describes what that
    visualization shows.

    Parameters
    ----------
    tract_file : str | Path
        Path to the subject tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to the reference image that matches the atlas coordinate space
        (e.g., MNI template if atlas is in MNI space). When given, both the
        subject and atlas streamlines are mapped with the inverse of this
        image's affine before comparison. If None, the streamlines are compared
        in RAS+mm as loaded, i.e. the atlas is assumed to be in the same space
        as the subject tract.
    flip_lr : bool, optional
        Whether to flip left-right (negate the X coordinate of) the atlas
        streamlines after any transformation. This may be needed for some
        coordinate conventions or file formats where the left-right
        orientation differs. Default is False.
    clust_thr : tuple[float, float, float], optional
        Clustering thresholds for QuickBundlesX used internally.
        Default is (5, 3, 1.5).
    threshold : float, optional
        Threshold controlling the strictness of the shape similarity assessment.
        A smaller threshold requires the bundles to be more similar to achieve
        a higher score. Default is 6.
    rng : np.random.Generator | None, optional
        Random number generator. If None, creates a new one. Default is None.

    Returns
    -------
    float
        Bundle similarity score (BA value). Higher values indicate greater
        similarity between the bundles.

    Raises
    ------
    InvalidInputError
        If either tract contains no streamlines.
    TractographyVisualizationError
        If loading or calculation fails. Underlying ``OSError`` (including
        ``FileNotFoundError``), ``ValueError``, ``RuntimeError`` and
        ``IndexError`` are wrapped and chained.

    Examples
    --------
    >>> visualizer = TractographyVisualizer()
    >>> score = visualizer.calculate_shape_similarity(
    ...     "subject_tract.trk", "atlas_tract.trk"
    ... )
    >>> print(f"Shape similarity score: {score}")
    """
    try:
        # Load both tracts and bring them into RAS+mm world coordinates.
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        if len(tract.streamlines) == 0:
            raise InvalidInputError("Subject tract is empty.")
        if len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tract is empty.")

        subject_streamlines = tract.streamlines
        atlas_streamlines = atlas_tract.streamlines

        if atlas_ref_img is not None:
            # Apply the SAME world->reference-grid mapping to both bundles.
            # Previously only the atlas was mapped, leaving the subject in mm.
            # That compared bundles living in different coordinate frames.
            atlas_ref_img_obj = nib.load(str(atlas_ref_img))
            to_ref_grid = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
            subject_streamlines = transform_streamlines(subject_streamlines, to_ref_grid)
            atlas_streamlines = transform_streamlines(atlas_streamlines, to_ref_grid)

        # Optional left-right mirror of the atlas (negate X); applied after any
        # transformation, identically for both branches above.
        if flip_lr:
            atlas_streamlines = Streamlines(
                [np.column_stack([-sl[:, 0], sl[:, 1], sl[:, 2]]) for sl in atlas_streamlines],
            )

        if rng is None:
            rng = np.random.default_rng()

        similarity_score = bundle_shape_similarity(
            subject_streamlines,
            atlas_streamlines,
            rng,
            clust_thr=clust_thr,
            threshold=threshold,
        )

        return float(similarity_score)
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to calculate shape similarity: {e}",
        ) from e
|
|
2269
|
+
|
|
2270
|
+
def visualize_shape_similarity(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    *,
    atlas_ref_img: str | Path | None = None,
    flip_lr: bool = False,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    subject_color: tuple[float, float, float] = (1.0, 0.0, 0.0),  # Red
    atlas_color: tuple[float, float, float] = (0.0, 0.0, 1.0),  # Blue
) -> dict[str, Path]:
    """Visualize shape similarity by overlaying subject and atlas tracts.

    Generates anatomical views (coronal, axial, sagittal) showing both tracts
    overlaid with different colors to visualize their shape similarity.

    Both tracts are loaded in RAS+mm and then mapped with the inverse affine of
    the reference image, so that they share the coordinate frame of the scene
    built from that image. Views whose output file already exists are not
    re-rendered. When every requested view already exists, the tracts are not
    loaded at all.

    Parameters
    ----------
    tract_file : str | Path
        Path to the subject tractography file.
    atlas_file : str | Path
        Path to the atlas tractography file.
    atlas_ref_img : str | Path | None, optional
        Path to the reference image that matches the atlas coordinate space
        (e.g., MNI template if atlas is in MNI space). If None, the reference
        image set during initialization is used.
    flip_lr : bool, optional
        Whether to flip left-right (negate X of) the atlas streamlines after
        transformation. Default is False.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    subject_color : tuple[float, float, float], optional
        RGB color for subject tract (0-1 range). Default is (1.0, 0.0, 0.0) (red).
    atlas_color : tuple[float, float, float], optional
        RGB color for atlas tract (0-1 range). Default is (0.0, 0.0, 1.0) (blue).

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal". A view whose rendering failed
        with a VTK memory-allocation error is logged and omitted.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, if a view
        name is invalid, or if either tract is empty.
    TractographyVisualizationError
        If visualization fails. Underlying ``OSError`` (including
        ``FileNotFoundError``), ``ValueError`` and ``RuntimeError`` are
        wrapped and chained.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> views = visualizer.visualize_shape_similarity(
    ...     "subject_tract.trk", "atlas_tract.trk"
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tract_name = self._extract_tract_name(tract_file)
    atlas_name = self._extract_tract_name(atlas_file)
    output_paths = {
        view_name: output_dir / f"similarity_{tract_name}_vs_{atlas_name}_{view_name}.png"
        for view_name in views_to_generate
    }
    generated_views: dict[str, Path] = {}

    # Check existing outputs BEFORE loading tracts to avoid needless memory use.
    all_files_exist = True
    for view_name, output_image in output_paths.items():
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    # The reference image is required for both the coordinate transform and the
    # scene. Without this fallback, a None value was passed to
    # nib.load("None") and always failed.
    if atlas_ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Pass atlas_ref_img, or set it via "
                "constructor or set_reference_image() method.",
            )
        atlas_ref_img = self._reference_image

    try:
        # Load both tracts only if we need to generate at least one view
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        if len(tract.streamlines) == 0:
            raise InvalidInputError("Subject tract is empty.")
        if len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tract is empty.")

        # Map both tracts from RAS+mm into the reference image's grid.
        atlas_ref_img_obj = nib.load(str(atlas_ref_img))
        to_ref_grid = np.linalg.inv(atlas_ref_img_obj.affine)  # type: ignore[attr-defined]
        subject_streamlines = transform_streamlines(tract.streamlines, to_ref_grid)
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, to_ref_grid)

        if flip_lr:
            atlas_streamlines = Streamlines(
                [np.column_stack([-sl[:, 0], sl[:, 1], sl[:, 2]]) for sl in atlas_streamlines],
            )

        # Camera target/distance and per-streamline colors do not depend on the
        # view, so compute them once.
        centroid = calculate_combined_centroid(subject_streamlines, atlas_streamlines)
        bbox_size = calculate_combined_bbox_size(subject_streamlines, atlas_streamlines)
        subject_colors = np.tile(subject_color, (len(subject_streamlines), 1))
        atlas_colors = np.tile(atlas_color, (len(atlas_streamlines), 1))

        for view_name, output_image in output_paths.items():
            # Skip if file already exists (already added to generated_views above)
            if output_image.exists():
                continue

            scene, brain_actor = self._create_scene(ref_img=atlas_ref_img, show_glass_brain=show_glass_brain)

            subject_actor = actor.line(subject_streamlines, colors=subject_colors)
            scene.add(subject_actor)
            atlas_actor = actor.line(atlas_streamlines, colors=atlas_colors)
            scene.add(atlas_actor)

            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            try:
                # Rendering can fail with std::bad_alloc surfaced as RuntimeError.
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            except RuntimeError as e:
                error_msg = str(e).lower()
                if not any(token in error_msg for token in ("bad_alloc", "memory", "allocation")):
                    raise
                logger.exception(
                    "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                    "Subject has %d streamlines, atlas has %d streamlines. "
                    "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                    view_name,
                    len(subject_streamlines),
                    len(atlas_streamlines),
                )
                # Memory failure: skip this view (cleanup happens in finally).
            else:
                generated_views[view_name] = output_image
            finally:
                # Release VTK resources on every path, including re-raised errors.
                scene.clear()
                del subject_actor, atlas_actor, brain_actor
                gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to visualize shape similarity: {e}",
        ) from e
    else:
        return generated_views
|
|
2504
|
+
|
|
2505
|
+
def compare_before_after_cci(
    self,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    before_color: tuple[float, float, float] = (0.7, 0.7, 0.7),  # Light gray
    after_color: tuple[float, float, float] = (0.0, 0.0, 1.0),  # Blue
) -> dict[str, Path]:
    """Compare tract before and after CCI filtering.

    Generates side-by-side anatomical views (coronal, axial, sagittal) showing
    the tract before CCI filtering (left) and after CCI filtering (right).

    "Before" means all streamlines longer than ``self.min_streamline_length``.
    "After" means the filtered tract returned by ``self.calc_cci``. Both
    panels use the same camera target (combined centroid) and the same camera
    distance (combined bounding box), so they are drawn at identical scale.
    Each panel is rendered to a temporary PNG and the two are concatenated
    horizontally. Temporary files are always removed.

    Parameters
    ----------
    tract_file : str | Path
        Path to the tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of each side of the comparison image in pixels. The final image
        will be twice as wide. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    before_color : tuple[float, float, float], optional
        RGB color for before CCI filtering tract (0-1 range).
        Default is (0.7, 0.7, 0.7) (light gray).
    after_color : tuple[float, float, float], optional
        RGB color for after CCI filtering tract (0-1 range).
        Default is (0.0, 0.0, 1.0) (blue).

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal". A view whose rendering failed
        with a VTK memory-allocation error is logged and omitted.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, or a view
        name is invalid.
    TractographyVisualizationError
        If comparison fails. Underlying ``OSError`` (including
        ``FileNotFoundError``), ``ValueError`` and ``RuntimeError`` are
        wrapped and chained.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> views = visualizer.compare_before_after_cci("tract.trk")
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    output_paths = {
        view_name: output_dir / f"cci_before_after_{tract_name}_{view_name}.png"
        for view_name in views_to_generate
    }
    generated_views: dict[str, Path] = {}

    # Check existing outputs BEFORE loading the tract to avoid needless memory use.
    all_files_exist = True
    for view_name, output_image in output_paths.items():
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    if all_files_exist:
        return generated_views

    try:
        # Load tract only if we need to generate at least one view
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        # Calculate CCI and get filtered tract
        _, filtered_tract = self.calc_cci(tract)

        ref_img_obj = nib.load(str(ref_img))
        to_ref_grid = np.linalg.inv(ref_img_obj.affine)  # type: ignore[attr-defined]

        # "Before" = streamlines that entered the CCI calculation, i.e. those
        # longer than min_streamline_length.
        lengths = length(tract.streamlines)
        before_streamlines = Streamlines(
            [sl for sl, sl_len in zip(tract.streamlines, lengths) if sl_len > self.min_streamline_length],
        )
        before_streamlines = transform_streamlines(before_streamlines, to_ref_grid)

        # "After" = the CCI-filtered tract
        after_streamlines = transform_streamlines(filtered_tract.streamlines, to_ref_grid)

        # Shared camera target AND distance so both panels have the same scale.
        # Per-panel bounding boxes zoomed the "after" panel differently.
        centroid = calculate_combined_centroid(before_streamlines, after_streamlines)
        bbox_size = calculate_combined_bbox_size(before_streamlines, after_streamlines)
        before_colors = np.tile(before_color, (len(before_streamlines), 1))
        after_colors = np.tile(after_color, (len(after_streamlines), 1))

        for view_name, output_image in output_paths.items():
            # Skip if file already exists (already added to generated_views above)
            if output_image.exists():
                continue

            # Left side: Before CCI filtering
            scene_before, brain_actor_before = self._create_scene(
                ref_img=ref_img,
                show_glass_brain=show_glass_brain,
            )
            before_actor = actor.line(before_streamlines, colors=before_colors)
            scene_before.add(before_actor)
            self._set_anatomical_camera(
                scene_before,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Right side: After CCI filtering
            scene_after, brain_actor_after = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)
            after_actor = actor.line(after_streamlines, colors=after_colors)
            scene_after.add(after_actor)
            self._set_anatomical_camera(
                scene_after,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Track temp files so they are removed on every exit path.
            temp_paths: list[str] = []
            try:
                with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_before:
                    temp_paths.append(tmp_before.name)
                with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_after:
                    temp_paths.append(tmp_after.name)
                tmp_before_path, tmp_after_path = temp_paths

                try:
                    # Rendering can fail with std::bad_alloc surfaced as RuntimeError.
                    window.record(
                        scene=scene_before,
                        out_path=tmp_before_path,
                        size=figure_size,
                    )
                    window.record(
                        scene=scene_after,
                        out_path=tmp_after_path,
                        size=figure_size,
                    )
                except RuntimeError as e:
                    error_msg = str(e).lower()
                    if not any(token in error_msg for token in ("bad_alloc", "memory", "allocation")):
                        raise
                    logger.exception(
                        "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                        "Before has %d streamlines, after has %d streamlines. "
                        "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                        view_name,
                        len(before_streamlines),
                        len(after_streamlines),
                    )
                    # Skip this view; the finally block below cleans up.
                    continue

                # Combine images side-by-side. The import lives inside the try so
                # the imageio fallback is actually reachable without Pillow.
                try:
                    from PIL import Image
                except ImportError:
                    # Fallback to imageio if PIL not available
                    img_before = imageio.imread(tmp_before_path)
                    img_after = imageio.imread(tmp_after_path)
                    combined = np.concatenate([img_before, img_after], axis=1)
                    imageio.imwrite(str(output_image), combined)
                else:
                    with Image.open(tmp_before_path) as img_before, Image.open(tmp_after_path) as img_after:
                        # Create combined image (twice as wide)
                        combined_img = Image.new("RGB", (figure_size[0] * 2, figure_size[1]), (255, 255, 255))
                        combined_img.paste(img_before, (0, 0))
                        combined_img.paste(img_after, (figure_size[0], 0))
                        combined_img.save(str(output_image))

                generated_views[view_name] = output_image
            finally:
                # Best-effort temp cleanup: each file is unlinked independently.
                for temp_path in temp_paths:
                    with contextlib.suppress(OSError):
                        os.unlink(temp_path)
                scene_before.clear()
                scene_after.clear()
                del before_actor, after_actor, brain_actor_before, brain_actor_after
                del scene_before, scene_after
                gc.collect()

    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to compare before/after CCI: {e}",
        ) from e
    else:
        return generated_views
|
|
2797
|
+
|
|
2798
|
+
def visualize_bundle_assignment(
    self,
    tract_file: str | Path,
    atlas_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    n_segments: int = 100,
    views: list[str] | None = None,
    output_dir: str | Path | None = None,
    figure_size: tuple[int, int] = (800, 800),
    show_glass_brain: bool = True,
    colormap: str = "random",
) -> dict[str, Path]:
    """Visualize bundle assignment map using DIPY's assignment_map.

    Assigns every point of every streamline in the target tract to a
    segment of the atlas (model) bundle using DIPY's ``assignment_map`` and
    colors each point by its assigned segment, which produces a banded look
    along the bundle. One image is rendered per requested anatomical view.
    Colors are computed once, so all views share the same coloring.

    Parameters
    ----------
    tract_file : str | Path
        Path to the target tractography file (streamlines to assign).
    atlas_file : str | Path
        Path to the atlas tractography file (reference bundle for assignment).
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses the reference image
        set during initialization.
    n_segments : int, optional
        Number of segments to divide the model bundle into. Each point of
        the target tract is assigned to the closest segment. Must be at
        least 1. Default is 100.
    views : list[str] | None, optional
        List of views to generate. Options: "coronal", "axial", "sagittal".
        If None, generates all three views. Default is None.
    output_dir : str | Path | None, optional
        Output directory for generated images. If None, uses the output
        directory set during initialization.
    figure_size : tuple[int, int], optional
        Size of the output images in pixels. Default is (800, 800).
    show_glass_brain : bool, optional
        Whether to show the glass brain outline. Default is True.
    colormap : str, optional
        ``"random"`` (or None) draws one random RGB color per segment, as in
        the DIPY example. Any other value is passed as ``name`` to
        ``create_colormap`` (e.g. "Spectral", "hsv", "rainbow", "turbo").
        Default is "random".

    Returns
    -------
    dict[str, Path]
        Dictionary mapping view names to their output file paths.
        Keys: "coronal", "axial", "sagittal" (the requested subset).
        Existing images are returned without being regenerated. A view
        whose rendering fails with a VTK memory-allocation error is logged
        and omitted.

    Raises
    ------
    InvalidInputError
        If no output directory or reference image is available, a view
        name is invalid, ``n_segments`` is smaller than 1, or either
        tractogram is empty.
    TractographyVisualizationError
        If visualization fails. OSError (e.g. a missing input file),
        ValueError, RuntimeError and IndexError raised while loading or
        rendering are re-raised as this error.

    Examples
    --------
    >>> visualizer = TractographyVisualizer(
    ...     reference_image="t1w.nii.gz", output_directory="output/"
    ... )
    >>> views = visualizer.visualize_bundle_assignment(
    ...     "target_tract.trk", "model_tract.trk", n_segments=100
    ... )
    """
    # Use standard anatomical view angles from utils
    view_angles = ANATOMICAL_VIEW_ANGLES

    # Determine which views to generate
    if views is None:
        views_to_generate = list(view_angles.keys())
    else:
        views_to_generate = views
        invalid_views = [v for v in views_to_generate if v not in view_angles]
        if invalid_views:
            raise InvalidInputError(
                f"Invalid view names: {invalid_views}. Valid options: {list(view_angles.keys())}",
            )

    # Get output directory
    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    tract_name = self._extract_tract_name(tract_file)
    generated_views: dict[str, Path] = {}

    # Check if all output files already exist BEFORE loading tracts.
    # This prevents unnecessary memory usage when files are already generated.
    all_files_exist = True
    for view_name in views_to_generate:
        output_image = output_dir / f"bundle_assignment_{tract_name}_{view_name}.png"
        if output_image.exists():
            generated_views[view_name] = output_image
            logger.debug("Skipping generation of %s (file already exists)", output_image)
        else:
            all_files_exist = False

    # If all files exist, return early without loading tracts
    if all_files_exist:
        return generated_views

    # Without at least one segment there is no color to assign.
    if n_segments < 1:
        raise InvalidInputError(f"n_segments must be at least 1, got {n_segments}.")

    try:
        # Load both tracts only if we need to generate at least one view
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()

        atlas_tract = load_trk(str(atlas_file), "same", bbox_valid_check=False)
        atlas_tract.to_rasmm()

        if not tract.streamlines or len(tract.streamlines) == 0:
            raise InvalidInputError("Target tractogram is empty.")
        if not atlas_tract.streamlines or len(atlas_tract.streamlines) == 0:
            raise InvalidInputError("Atlas tractogram is empty.")

        # Transform both tracts into the reference image's voxel space:
        # assignment_map needs both bundles in the same coordinate space.
        ref_img_obj = nib.load(str(ref_img))
        inv_affine = np.linalg.inv(ref_img_obj.affine)  # type: ignore[attr-defined]
        tract_streamlines = transform_streamlines(tract.streamlines, inv_affine)
        atlas_streamlines = transform_streamlines(atlas_tract.streamlines, inv_affine)
        del tract, atlas_tract

        # One assignment per point, ordered as the points appear when walking
        # the streamlines sequentially (all points of streamline 0, then 1, ...).
        assignment_indices = np.asarray(
            assignment_map(tract_streamlines, atlas_streamlines, n_segments),
        ).astype(np.intp, copy=False)
        del atlas_streamlines

        # Build one RGB color per segment, shape (n_segments, 3).
        if colormap == "random" or colormap is None:
            # Random colors, as in the DIPY example, for maximal differentiation
            rng = np.random.default_rng()
            segment_colors = rng.random((n_segments, 3))
        else:
            segment_colors = np.asarray(
                create_colormap(np.linspace(0, 1, n_segments), name=colormap),
                dtype=float,
            )[:, :3]  # drop an alpha channel if present

        # NOTE(review): assignment_map indexes the points of all model-bundle
        # centroids it computes; each centroid is assumed to be resampled to
        # n_segments points, so ``index % n_segments`` recovers the position
        # along the bundle for indices beyond the first centroid (previously
        # these all fell back to segment 0). Confirm against the DIPY version.
        segment_idx = np.where(assignment_indices < 0, 0, assignment_indices % n_segments)
        point_colors_array = segment_colors[segment_idx].astype(np.float32)

        # actor.line expects per-point colors as an (N points, 3) array
        if point_colors_array.ndim != 2 or point_colors_array.shape[1] != 3:  # noqa: PLR2004
            msg = (
                f"point_colors_array has incorrect shape: {point_colors_array.shape}. "
                f"Expected (N, 3) where N is the number of points."
            )
            raise ValueError(msg)

        # Debug: sample of assignments and colors for the first streamline
        first_len = len(tract_streamlines[0])
        first_sl_segments = segment_idx[:first_len]
        logger.debug(
            "First streamline: %d points, %d unique segments",
            first_len,
            len(np.unique(first_sl_segments)),
        )
        logger.debug("Sample assignments (first 10): %s", first_sl_segments[:10])
        logger.debug("Sample colors (first 3): %s", point_colors_array[:3])

        # Camera inputs do not depend on the view, so compute them once.
        centroid = calculate_centroid(tract_streamlines)
        bbox_size = calculate_bbox_size(tract_streamlines)
        gc.collect()

        # Generate each requested view
        for view_name in views_to_generate:
            output_image = output_dir / f"bundle_assignment_{tract_name}_{view_name}.png"

            # Skip if file already exists (already added to generated_views above)
            if output_image.exists():
                continue

            scene, brain_actor = self._create_scene(ref_img=ref_img, show_glass_brain=show_glass_brain)
            tract_actor = actor.line(tract_streamlines, colors=point_colors_array, fake_tube=True, linewidth=6)
            scene.add(tract_actor)

            # Anatomical views are produced by moving the camera, not the streamlines
            self._set_anatomical_camera(
                scene,
                centroid,
                view_name,
                bbox_size=bbox_size,
            )

            # Recording can fail with std::bad_alloc on large tracts
            try:
                window.record(
                    scene=scene,
                    out_path=str(output_image),
                    size=figure_size,
                )
            except RuntimeError as e:
                error_msg = str(e).lower()
                if not any(token in error_msg for token in ("bad_alloc", "memory", "allocation")):
                    raise
                logger.exception(
                    "VTK memory allocation failed during rendering for view %s (likely std::bad_alloc). "
                    "Tract has %d streamlines with %d total points. "
                    "Try reducing image resolution (figure_size) or processing with n_jobs=1.",
                    view_name,
                    len(tract_streamlines),
                    sum(len(sl) for sl in tract_streamlines),
                )
                # Memory failure: skip this view and continue with the next one
            else:
                generated_views[view_name] = output_image
            finally:
                # Release VTK resources for this view on every path
                scene.clear()
                del tract_actor, brain_actor
                gc.collect()

    except (OSError, ValueError, RuntimeError, IndexError) as e:
        raise TractographyVisualizationError(
            f"Failed to visualize bundle assignment: {e}",
        ) from e
    else:
        return generated_views
|
|
3103
|
+
|
|
3104
|
+
def generate_gif(
    self,
    name: str,
    tract_file: str | Path,
    ref_img: str | Path | None = None,
    *,
    ref_file: str | Path | None = None,  # Alias for ref_img
    output_dir: str | Path | None = None,
) -> Path:
    """Generate a GIF animation of rotating tractography.

    The streamlines are mapped into the reference image's voxel space and
    rotated about the Z axis (through the origin) in ``self.gif_frames``
    equal steps over a full turn. The glass brain is rotated with them
    through a VTK user transform. Each frame is rendered with
    ``window.snapshot`` at ``self.gif_size``, and the frames are written
    with ``imageio.mimsave`` using ``self.gif_duration`` and
    ``self.gif_palette_size``.

    Parameters
    ----------
    name : str
        Base name for the output GIF file (without extension).
    tract_file : str | Path
        Path to the tractography file.
    ref_img : str | Path | None, optional
        Path to the reference image. If None, uses ``ref_file`` when given,
        otherwise the reference image set during initialization.
    ref_file : str | Path | None, optional
        Backward-compatible alias for ``ref_img``; used only when
        ``ref_img`` is None.
    output_dir : str | Path | None, optional
        Output directory. If None, uses the output directory set during
        initialization.

    Returns
    -------
    Path
        Path to the generated GIF file. An existing file is returned
        without being regenerated. If a VTK memory-allocation error
        interrupts rendering after at least one frame was captured, the
        frames captured so far are saved as a partial GIF.

    Raises
    ------
    InvalidInputError
        If no reference image or output directory is available.
    TractographyVisualizationError
        If GIF generation fails (including when no frame could be rendered).
        OSError (e.g. a missing input file), ValueError and RuntimeError are
        re-raised as this error.
    """
    # Resolve the alias first; otherwise the instance default (or a missing-
    # reference error) would take precedence over an explicit ref_file.
    if ref_img is None and ref_file is not None:
        ref_img = ref_file

    if ref_img is None:
        if self._reference_image is None:
            raise InvalidInputError(
                "No reference image provided. Set it via constructor or "
                "set_reference_image() method, or pass it as an argument.",
            )
        ref_img = self._reference_image
    else:
        ref_img = Path(ref_img)

    if output_dir is None:
        if self._output_directory is None:
            raise InvalidInputError(
                "No output directory provided. Set it via constructor or "
                "set_output_directory() method, or pass it as an argument.",
            )
        output_dir = self._output_directory
    else:
        output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    gif_filename = output_dir / f"{name}.gif"

    # Skip if file already exists
    if gif_filename.exists():
        logger.debug("Skipping generation of %s (file already exists)", gif_filename)
        return gif_filename

    try:
        # Load tract and map it into the reference image's voxel space
        tract = load_trk(str(tract_file), "same", bbox_valid_check=False)
        tract.to_rasmm()
        ref_img_obj = nib.load(str(ref_img))
        tract_streamlines = transform_streamlines(
            tract.streamlines,
            np.linalg.inv(ref_img_obj.affine),  # type: ignore[attr-defined]
        )
        del tract, ref_img_obj

        brain_actor = self.get_glass_brain(ref_img)
        scene = window.Scene()
        scene.setBackground(color=(1, 1, 1))

        angles = np.linspace(0, 360, self.gif_frames, endpoint=False)
        rotation_axis = np.array([0, 0, 1])  # Rotate around Z-axis
        rotation_center = np.array([0, 0, 0])
        frames = []

        for angle in angles:
            rot_matrix = Rotation.from_rotvec(
                angle * np.pi / 180 * rotation_axis,
            ).as_matrix()

            # Rotate streamlines around the rotation center
            rotated_streamlines = [
                np.dot(s - rotation_center, rot_matrix.T) + rotation_center for s in tract_streamlines
            ]
            stream_actor = actor.line(rotated_streamlines)

            # Same rotation as a 4x4 homogeneous matrix for the glass brain
            transform_matrix = np.eye(4)
            transform_matrix[:3, :3] = rot_matrix
            transform_matrix[:3, 3] = rotation_center - np.dot(
                rot_matrix,
                rotation_center,
            )
            transform = vtk.vtkTransform()
            transform.Concatenate(transform_matrix.flatten())
            brain_actor.SetUserTransform(transform)

            # Only the rotated streamlines and the brain are rendered per frame
            scene.clear()
            scene.add(stream_actor)
            scene.add(brain_actor)

            # Snapshot can fail with std::bad_alloc on large tracts
            try:
                frames.append(window.snapshot(scene, size=self.gif_size))
            except RuntimeError as e:
                error_msg = str(e).lower()
                if not any(token in error_msg for token in ("bad_alloc", "memory", "allocation")):
                    raise
                logger.exception(
                    "VTK memory allocation failed during GIF frame snapshot (likely std::bad_alloc). "
                    "Tract has %d streamlines with %d total points. "
                    "Try reducing gif_size or processing with n_jobs=1.",
                    len(tract_streamlines),
                    sum(len(sl) for sl in tract_streamlines),
                )
                # Stop rendering; whatever was captured so far is saved below
                break
            finally:
                # Force per-frame memory cleanup on every path
                del stream_actor, rotated_streamlines
                gc.collect()

        # Clean up scene and actors after all frames are captured
        scene.clear()
        del brain_actor, scene, tract_streamlines

        if not frames:
            raise TractographyVisualizationError(
                "Failed to generate GIF: no frames could be rendered.",
            )
        if len(frames) < len(angles):
            logger.warning("Saving partial GIF with %d frames", len(frames))

        # Save as optimized GIF
        imageio.mimsave(
            str(gif_filename),
            frames,
            duration=self.gif_duration,
            palettesize=self.gif_palette_size,
        )

        # Clean up frames after saving
        del frames
        gc.collect()
    except TractographyVisualizationError:
        raise
    except (OSError, ValueError, RuntimeError) as e:
        raise TractographyVisualizationError(
            f"Failed to generate GIF: {e}",
        ) from e
    else:
        return gif_filename
|
|
3284
|
+
|
|
3285
|
+
def convert_gif_to_mp4(
    self,
    gif_path: str | Path,
    mp4_path: str | Path | None = None,
    *,
    fps: int = 10,
) -> Path:
    """Convert a GIF file to MP4 format.

    Frames are read with ``imageio.get_reader`` and written through the
    FFMPEG writer with the ``libx264`` codec. An existing MP4 at the target
    path is returned without conversion. If conversion fails, any partially
    written MP4 is removed so that a later call does not mistake it for a
    finished video.

    Parameters
    ----------
    gif_path : str | Path
        Path to the input GIF file.
    mp4_path : str | Path | None, optional
        Path to the output MP4 file. If None, uses the same name as GIF
        with .mp4 extension.
    fps : int, optional
        Frames per second for the video. Default is 10.

    Returns
    -------
    Path
        Path to the generated MP4 file.

    Raises
    ------
    TractographyVisualizationError
        If conversion fails. OSError (including a missing GIF file),
        ValueError and RuntimeError are re-raised as this error.
    """
    gif_path_obj = Path(gif_path)
    mp4_path = gif_path_obj.with_suffix(".mp4") if mp4_path is None else Path(mp4_path)
    mp4_path.parent.mkdir(parents=True, exist_ok=True)

    # Skip if file already exists
    if mp4_path.exists():
        logger.debug("Skipping conversion of %s to %s (file already exists)", gif_path, mp4_path)
        return mp4_path

    try:
        # Context managers close both handles even if a frame fails to copy.
        with imageio.get_reader(str(gif_path_obj)) as reader, imageio.get_writer(
            str(mp4_path),
            format="FFMPEG",  # type: ignore[arg-type]
            fps=fps,
            codec="libx264",
        ) as writer:
            for frame in reader:  # type: ignore[attr-defined]
                writer.append_data(frame)
    except (OSError, ValueError, RuntimeError) as e:
        # The target did not exist before this call (checked above), so any
        # file there now is a truncated result of this failed conversion.
        with contextlib.suppress(OSError):
            mp4_path.unlink()
        raise TractographyVisualizationError(
            f"Failed to convert GIF to MP4: {e}",
        ) from e
    else:
        return mp4_path
|
|
3348
|
+
|
|
3349
|
+
def generate_videos(
    self,
    tract_files: list[str | Path],
    ref_img: str | Path | None = None,
    *,
    ref_file: str | Path | None = None,  # Alias for ref_img
    output_dir: str | Path | None = None,
    remove_gifs: bool = True,
) -> dict[str, Path]:
    """Render one rotating MP4 video per tractography file.

    For each tract, a GIF is built with ``generate_gif`` and converted with
    ``convert_gif_to_mp4``. The video lands at
    ``<output_dir>/<tract file stem>.mp4``. Tracts whose MP4 already exists
    are not rendered again.

    Parameters
    ----------
    tract_files : list[str | Path]
        Tractography files to render.
    ref_img : str | Path | None, optional
        Reference image path. Falls back to ``ref_file``, then to the
        reference image configured on the instance.
    ref_file : str | Path | None, optional
        Backward-compatible alias for ``ref_img``.
    output_dir : str | Path | None, optional
        Destination directory. Falls back to the output directory configured
        on the instance.
    remove_gifs : bool, optional
        Delete the intermediate GIF after conversion. Default is True.

    Returns
    -------
    dict[str, Path]
        Maps each tract file stem to its MP4 path.

    Raises
    ------
    InvalidInputError
        If ``tract_files`` is empty, or no output directory / reference
        image can be resolved.
    TractographyVisualizationError
        If producing a video fails. OSError, ValueError and RuntimeError
        raised while processing a tract are re-raised as this error.
    """
    if not tract_files:
        raise InvalidInputError("No tract files provided.")

    # ``ref_file`` only fills in when ``ref_img`` was not given.
    if ref_img is None:
        ref_img = ref_file

    if output_dir is not None:
        output_dir = Path(output_dir)
    elif self._output_directory is not None:
        output_dir = self._output_directory
    else:
        raise InvalidInputError(
            "No output directory provided. Set it via constructor or "
            "set_output_directory() method, or pass it as an argument.",
        )
    output_dir.mkdir(parents=True, exist_ok=True)

    if ref_img is not None:
        ref_img = Path(ref_img)
    elif self._reference_image is not None:
        ref_img = self._reference_image
    else:
        raise InvalidInputError(
            "No reference image provided. Set it via constructor or "
            "set_reference_image() method, or pass it as an argument.",
        )

    videos: dict[str, Path] = {}

    for tract_file in tract_files:
        try:
            source = Path(tract_file)
            stem = source.stem
            video_path = output_dir / f"{stem}.mp4"

            if video_path.exists():
                logger.debug("Skipping generation of %s (file already exists)", video_path)
            else:
                gif_path = self.generate_gif(
                    name=stem,
                    tract_file=source,
                    ref_img=ref_img,
                    output_dir=output_dir,
                )
                self.convert_gif_to_mp4(gif_path, video_path)
                if remove_gifs:
                    gif_path.unlink()

            videos[stem] = video_path
        except (OSError, ValueError, RuntimeError) as e:
            raise TractographyVisualizationError(
                f"Failed to generate video for {tract_file}: {e}",
            ) from e

    return videos
|
|
3447
|
+
|
|
3448
|
+
def _process_single_tract(
    self,
    subject_id: str,
    tract_name: str,
    tract_file: str | Path,
    subject_ref_img: Path,
    tract_output_dir: Path,
    subjects_mni_space: dict[str, dict[str, str | Path]] | None,
    atlas_files: dict[str, str | Path] | None,
    metric_files: dict[str, dict[str, str | Path]] | None,
    atlas_ref_img: str | Path | None,
    *,
    flip_lr: bool,
    skip_checks: list[str],
    **kwargs: Any,
) -> dict[str, str | Path]:
    """Run every enabled quality check for one subject/tract combination.

    Helper for ``run_quality_check_workflow``: it processes a single task so
    that tasks can be dispatched independently (in parallel or sequentially).
    Each check runs in its own ``try`` block; if a check raises ``OSError``,
    ``ValueError``, ``RuntimeError`` or ``IndexError``, the failure is logged
    as a warning and the remaining checks still run.

    Parameters
    ----------
    subject_id : str
        Subject identifier, used for MNI/metric lookups and log messages.
    tract_name : str
        Tract identifier, used as key into ``subjects_mni_space[subject_id]``
        and ``atlas_files``.
    tract_file : str | Path
        Tract file in the subject's original space.
    subject_ref_img : Path
        Reference image for the original-space tract.
    tract_output_dir : Path
        Directory passed as ``output_dir`` to every check.
    subjects_mni_space : dict[str, dict[str, str | Path]] | None
        ``{subject_id: {tract_name: path}}`` for MNI-space tracts, or None.
        Needed for the subject MNI views, shape similarity and bundle assignment.
    atlas_files : dict[str, str | Path] | None
        ``{tract_name: atlas_path}``. The atlas file also serves as the model
        file for AFQ profiles and bundle assignment.
    metric_files : dict[str, dict[str, str | Path]] | None
        ``{subject_id: {metric_name: path}}`` used for AFQ profiles.
    atlas_ref_img : str | Path | None
        Reference image for the atlas/MNI space. Bundle assignment falls back
        to ``subject_ref_img`` when this is None.
    flip_lr : bool
        Forwarded to the atlas-view and shape-similarity methods.
    skip_checks : list[str]
        Names of checks to skip ("anatomical_views", "cci", "before_after_cci",
        "atlas_comparison", "shape_similarity", "afq_profile",
        "bundle_assignment").
    **kwargs
        Forwarded to every check method.

    Returns
    -------
    dict[str, str | Path]
        Generated outputs keyed by check-specific prefixes ("anatomical_",
        "cci_", "before_after_cci_", "subject_mni_", "atlas_", "similarity_",
        "afq_<metric>_", "assignment_"), plus "shape_similarity_score" holding
        the score as a string. Checks that are skipped or fail add no keys.
    """
    # Exceptions that mark a single check as failed without aborting the others.
    recoverable = (OSError, ValueError, RuntimeError, IndexError)
    tract_results: dict[str, str | Path] = {}

    def _add(prefix: str, outputs: dict[str, str | Path]) -> None:
        """Store each generated output under ``prefix + key``."""
        for key, path in outputs.items():
            tract_results[f"{prefix}{key}"] = path

    # Resolve the optional MNI-space tract and atlas/model file once; several
    # checks below depend on them.
    subject_mni_tracts = subjects_mni_space.get(subject_id, {}) if subjects_mni_space is not None else {}
    tract_file_mni = subject_mni_tracts.get(tract_name)
    atlas_file = atlas_files.get(tract_name) if atlas_files is not None else None
    # MNI-space tracts must be paired with an MNI-space reference image (as the
    # atlas comparison does); keep the subject reference only as a fallback.
    mni_ref_img = atlas_ref_img if atlas_ref_img is not None else subject_ref_img

    try:
        # 1. Standard anatomical views (original space)
        if "anatomical_views" not in skip_checks:
            try:
                _add(
                    "anatomical_",
                    self.generate_anatomical_views(
                        tract_file,
                        output_dir=tract_output_dir,
                        ref_img=subject_ref_img,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to generate anatomical views for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

        # 2. CCI calculation and visualization (original space)
        if "cci" not in skip_checks:
            try:
                _add(
                    "cci_",
                    self.plot_cci(
                        tract_file,  # type: ignore[arg-type]
                        output_dir=tract_output_dir,
                        ref_img=subject_ref_img,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to generate CCI plots for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

        # 3. Before/after CCI comparison (original space)
        if "before_after_cci" not in skip_checks:
            try:
                _add(
                    "before_after_cci_",
                    self.compare_before_after_cci(
                        tract_file,
                        output_dir=tract_output_dir,
                        ref_img=subject_ref_img,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to generate before/after CCI comparison for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

        # 4. Atlas comparison; subject views are added only when an MNI tract exists
        if "atlas_comparison" not in skip_checks and atlas_file is not None:
            try:
                if tract_file_mni is not None:
                    _add(
                        "subject_mni_",
                        self.generate_anatomical_views(
                            tract_file_mni,
                            ref_img=atlas_ref_img,  # Use atlas ref image for MNI space
                            output_dir=tract_output_dir,
                            **kwargs,
                        ),
                    )
                _add(
                    "atlas_",
                    self.generate_atlas_views(
                        atlas_file,
                        atlas_ref_img=atlas_ref_img,
                        flip_lr=flip_lr,
                        output_dir=tract_output_dir,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to generate atlas comparison views for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

        # 5. Shape similarity (MNI-space tract vs. atlas)
        if "shape_similarity" not in skip_checks and tract_file_mni is not None and atlas_file is not None:
            try:
                similarity_score = self.calculate_shape_similarity(
                    tract_file_mni,
                    atlas_file,
                    atlas_ref_img=atlas_ref_img,
                    flip_lr=flip_lr,
                    **kwargs,
                )
                tract_results["shape_similarity_score"] = str(similarity_score)
                _add(
                    "similarity_",
                    self.visualize_shape_similarity(
                        tract_file_mni,
                        atlas_file,
                        atlas_ref_img=atlas_ref_img,
                        flip_lr=flip_lr,
                        output_dir=tract_output_dir,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to calculate/visualize shape similarity for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

        # 6. AFQ profile: one per subject metric; the atlas file is the model file
        if (
            "afq_profile" not in skip_checks
            and metric_files is not None
            and subject_id in metric_files
            and atlas_file is not None
        ):
            for metric_name, metric_file in metric_files[subject_id].items():
                try:
                    _add(
                        f"afq_{metric_name}_",
                        self.plot_afq(
                            metric_file,
                            metric_name,
                            tract_file,
                            atlas_file,
                            ref_img=subject_ref_img,
                            output_dir=tract_output_dir,
                            **kwargs,
                        ),
                    )
                except recoverable as e:
                    logger.warning(
                        "Failed to generate AFQ profile for %s/%s/%s: %s",
                        subject_id,
                        tract_name,
                        metric_name,
                        e,
                    )

        # 7. Bundle assignment (MNI-space tract; atlas file as model file)
        if "bundle_assignment" not in skip_checks and tract_file_mni is not None and atlas_file is not None:
            try:
                _add(
                    "assignment_",
                    self.visualize_bundle_assignment(
                        tract_file_mni,
                        atlas_file,
                        output_dir=tract_output_dir,
                        ref_img=mni_ref_img,
                        **kwargs,
                    ),
                )
            except recoverable as e:
                logger.warning(
                    "Failed to generate bundle assignment for %s/%s: %s",
                    subject_id,
                    tract_name,
                    e,
                )

    except recoverable:
        # Safety net for errors raised outside the per-check handlers.
        logger.exception("Error processing %s/%s", subject_id, tract_name)
    finally:
        # Clean up memory after processing this tract
        gc.collect()

    return tract_results
|
|
3680
|
+
|
|
3681
|
+
def run_quality_check_workflow(
|
|
3682
|
+
self,
|
|
3683
|
+
subjects_original_space: dict[str, dict[str, str | Path]],
|
|
3684
|
+
ref_img: str | Path | dict[str, str | Path] | None = None,
|
|
3685
|
+
*,
|
|
3686
|
+
subjects_mni_space: dict[str, dict[str, str | Path]] | None = None,
|
|
3687
|
+
atlas_files: dict[str, str | Path] | None = None,
|
|
3688
|
+
metric_files: dict[str, dict[str, str | Path]] | None = None,
|
|
3689
|
+
atlas_ref_img: str | Path | None = None,
|
|
3690
|
+
flip_lr: bool = False,
|
|
3691
|
+
output_dir: str | Path | None = None,
|
|
3692
|
+
html_output: str | Path | None = None,
|
|
3693
|
+
skip_checks: list[str] | None = None,
|
|
3694
|
+
n_jobs: int | None = None,
|
|
3695
|
+
**kwargs: Any,
|
|
3696
|
+
) -> dict[str, dict[str, dict[str, str | Path]]]:
|
|
3697
|
+
"""Run comprehensive quality checks for multiple subjects and tracts.
|
|
3698
|
+
|
|
3699
|
+
This workflow function orchestrates all available quality check methods
|
|
3700
|
+
for each subject/tract combination and generates an HTML report.
|
|
3701
|
+
|
|
3702
|
+
Different quality checks require tracts in different coordinate spaces:
|
|
3703
|
+
- **Original space** (subjects_original_space): Used for anatomical views,
|
|
3704
|
+
CCI calculations, before/after CCI comparison, and AFQ profiles
|
|
3705
|
+
(which need to align with subject-specific metric files).
|
|
3706
|
+
- **MNI/Atlas space** (subjects_mni_space): Used for shape similarity
|
|
3707
|
+
calculations and bundle assignment (which compare with atlas files).
|
|
3708
|
+
|
|
3709
|
+
Parameters
|
|
3710
|
+
----------
|
|
3711
|
+
subjects_original_space : dict[str, dict[str, str | Path]]
|
|
3712
|
+
Dictionary mapping subject IDs to their tract files in original/native space.
|
|
3713
|
+
Format: {subject_id: {tract_name: tract_file_path}}
|
|
3714
|
+
Example:
|
|
3715
|
+
{
|
|
3716
|
+
"sub-001": {
|
|
3717
|
+
"AF_L": "path/to/sub-001_AF_L_original.trk",
|
|
3718
|
+
"AF_R": "path/to/sub-001_AF_R_original.trk"
|
|
3719
|
+
},
|
|
3720
|
+
"sub-002": {
|
|
3721
|
+
"AF_L": "path/to/sub-002_AF_L_original.trk"
|
|
3722
|
+
}
|
|
3723
|
+
}
|
|
3724
|
+
Used for: anatomical views, CCI, before/after CCI, AFQ profiles.
|
|
3725
|
+
ref_img : str | Path | dict[str, str | Path] | None, optional
|
|
3726
|
+
Reference image(s) for subjects. Can be:
|
|
3727
|
+
- A single path (str | Path): Used for all subjects
|
|
3728
|
+
- A dictionary mapping subject IDs to reference images:
|
|
3729
|
+
{subject_id: ref_image_path}
|
|
3730
|
+
- None: Uses the reference image set during initialization or via
|
|
3731
|
+
`set_reference_image()` for all subjects
|
|
3732
|
+
Example for per-subject reference images:
|
|
3733
|
+
{"sub-001": "path/to/sub-001_t1w.nii.gz", "sub-002": "path/to/sub-002_t1w.nii.gz"}
|
|
3734
|
+
subjects_mni_space : dict[str, dict[str, str | Path]] | None, optional
|
|
3735
|
+
Dictionary mapping subject IDs to their tract files in MNI/atlas space.
|
|
3736
|
+
Format: {subject_id: {tract_name: tract_file_path}}
|
|
3737
|
+
Example:
|
|
3738
|
+
{
|
|
3739
|
+
"sub-001": {
|
|
3740
|
+
"AF_L": "path/to/sub-001_AF_L_mni.trk",
|
|
3741
|
+
"AF_R": "path/to/sub-001_AF_R_mni.trk"
|
|
3742
|
+
},
|
|
3743
|
+
"sub-002": {
|
|
3744
|
+
"AF_L": "path/to/sub-002_AF_L_mni.trk"
|
|
3745
|
+
}
|
|
3746
|
+
}
|
|
3747
|
+
Used for: shape similarity, bundle assignment.
|
|
3748
|
+
If None, these checks will be skipped.
|
|
3749
|
+
atlas_files : dict[str, str | Path] | None, optional
|
|
3750
|
+
Dictionary mapping tract names to their corresponding atlas/model files.
|
|
3751
|
+
These files are shared across all subjects and used for:
|
|
3752
|
+
- Atlas comparison visualizations
|
|
3753
|
+
- Shape similarity calculations
|
|
3754
|
+
- AFQ profile calculations (as model files)
|
|
3755
|
+
- Bundle assignment visualizations (as model files)
|
|
3756
|
+
Format: {tract_name: atlas_file_path}
|
|
3757
|
+
Example: {"AF_L": "path/to/atlas_AF_L.trk", "AF_R": "path/to/atlas_AF_R.trk"}
|
|
3758
|
+
metric_files : dict[str, dict[str, str | Path]] | None, optional
|
|
3759
|
+
Dictionary mapping subject IDs to their metric files.
|
|
3760
|
+
All tracts within a subject will use the same metric files.
|
|
3761
|
+
Format: {subject_id: {metric_name: metric_file_path}}
|
|
3762
|
+
Example:
|
|
3763
|
+
{
|
|
3764
|
+
"sub-001": {"FA": "path/to/sub-001_FA.nii.gz", "MD": "path/to/sub-001_MD.nii.gz"},
|
|
3765
|
+
"sub-002": {"FA": "path/to/sub-002_FA.nii.gz"}
|
|
3766
|
+
}
|
|
3767
|
+
If provided, AFQ profile calculations will be run for all tracts in each subject.
|
|
3768
|
+
atlas_ref_img : str | Path | None, optional
|
|
3769
|
+
Path to the reference image matching the atlas coordinate space
|
|
3770
|
+
(e.g., MNI template). Required if atlas files are in a different
|
|
3771
|
+
coordinate space than subject tracts.
|
|
3772
|
+
flip_lr : bool, optional
|
|
3773
|
+
Whether to flip left-right (X-axis) when transforming atlas.
|
|
3774
|
+
Default is False.
|
|
3775
|
+
output_dir : str | Path | None, optional
|
|
3776
|
+
Output directory for generated files. If None, uses the output
|
|
3777
|
+
directory set during initialization.
|
|
3778
|
+
html_output : str | Path | None, optional
|
|
3779
|
+
Path for the HTML report file. If None, creates "quality_check_report.html"
|
|
3780
|
+
in the output directory.
|
|
3781
|
+
skip_checks : list[str] | None, optional
|
|
3782
|
+
List of quality checks to skip. Valid options:
|
|
3783
|
+
- "anatomical_views": Skip standard anatomical views
|
|
3784
|
+
- "atlas_comparison": Skip atlas comparison views
|
|
3785
|
+
- "cci": Skip CCI calculation and visualization
|
|
3786
|
+
- "before_after_cci": Skip before/after CCI comparison
|
|
3787
|
+
- "afq_profile": Skip AFQ profile visualization
|
|
3788
|
+
- "bundle_assignment": Skip bundle assignment visualization
|
|
3789
|
+
- "shape_similarity": Skip shape similarity calculation and visualization
|
|
3790
|
+
n_jobs : int | None, optional
|
|
3791
|
+
Number of parallel jobs to run for processing multiple subjects/tracts.
|
|
3792
|
+
If None, uses the value set during initialization (default: 1).
|
|
3793
|
+
Use -1 to automatically determine optimal number based on available
|
|
3794
|
+
resources (respects SLURM allocations and OpenMP thread settings to
|
|
3795
|
+
prevent oversubscription). Only effective when processing multiple
|
|
3796
|
+
subjects/tracts. Default is None.
|
|
3797
|
+
|
|
3798
|
+
Note: When running under SLURM, -1 will automatically use
|
|
3799
|
+
SLURM_CPUS_PER_TASK or SLURM_JOB_CPUS_PER_NODE. If OMP_NUM_THREADS is set,
|
|
3800
|
+
it will divide the available CPUs by the number of OpenMP threads to
|
|
3801
|
+
prevent resource contention.
|
|
3802
|
+
**kwargs
|
|
3803
|
+
Additional keyword arguments passed to individual quality check methods.
|
|
3804
|
+
|
|
3805
|
+
Returns
|
|
3806
|
+
-------
|
|
3807
|
+
dict[str, dict[str, dict[str, str | Path]]]
|
|
3808
|
+
Nested dictionary structure: {subject_id: {tract_name: {media_type: file_path}}}
|
|
3809
|
+
This structure is compatible with `create_quality_check_html()`.
|
|
3810
|
+
|
|
3811
|
+
Raises
|
|
3812
|
+
------
|
|
3813
|
+
InvalidInputError
|
|
3814
|
+
If required files are missing or invalid.
|
|
3815
|
+
TractographyVisualizationError
|
|
3816
|
+
If quality check workflow fails.
|
|
3817
|
+
|
|
3818
|
+
Examples
|
|
3819
|
+
--------
|
|
3820
|
+
Single subject or all subjects share same reference image:
|
|
3821
|
+
>>> visualizer = TractographyVisualizer(output_directory="output/")
|
|
3822
|
+
>>> subjects_original = {
|
|
3823
|
+
... "sub-001": {
|
|
3824
|
+
... "AF_L": "sub-001_AF_L_original.trk",
|
|
3825
|
+
... "AF_R": "sub-001_AF_R_original.trk",
|
|
3826
|
+
... }
|
|
3827
|
+
... }
|
|
3828
|
+
>>> results = visualizer.run_quality_check_workflow(
|
|
3829
|
+
... subjects_original_space=subjects_original,
|
|
3830
|
+
... ref_img="shared_t1w.nii.gz", # Single image for all subjects
|
|
3831
|
+
... html_output="quality_report.html",
|
|
3832
|
+
... )
|
|
3833
|
+
|
|
3834
|
+
Multiple subjects with different reference images:
|
|
3835
|
+
>>> visualizer = TractographyVisualizer(output_directory="output/")
|
|
3836
|
+
>>> subjects_original = {
|
|
3837
|
+
... "sub-001": {"AF_L": "sub-001_AF_L_original.trk"},
|
|
3838
|
+
... "sub-002": {"AF_L": "sub-002_AF_L_original.trk"},
|
|
3839
|
+
... }
|
|
3840
|
+
>>> # Each subject has its own reference image
|
|
3841
|
+
>>> ref_images = {
|
|
3842
|
+
... "sub-001": "sub-001_t1w.nii.gz",
|
|
3843
|
+
... "sub-002": "sub-002_t1w.nii.gz",
|
|
3844
|
+
... }
|
|
3845
|
+
>>> subjects_mni = {
|
|
3846
|
+
... "sub-001": {"AF_L": "sub-001_AF_L_mni.trk"},
|
|
3847
|
+
... "sub-002": {"AF_L": "sub-002_AF_L_mni.trk"},
|
|
3848
|
+
... }
|
|
3849
|
+
>>> atlas_files = {"AF_L": "atlas_AF_L.trk"}
|
|
3850
|
+
>>> metric_files = {
|
|
3851
|
+
... "sub-001": {"FA": "sub-001_FA.nii.gz"},
|
|
3852
|
+
... "sub-002": {"FA": "sub-002_FA.nii.gz"},
|
|
3853
|
+
... }
|
|
3854
|
+
>>> results = visualizer.run_quality_check_workflow(
|
|
3855
|
+
... subjects_original_space=subjects_original,
|
|
3856
|
+
... ref_img=ref_images, # Dictionary mapping subject_id -> ref_image
|
|
3857
|
+
... subjects_mni_space=subjects_mni,
|
|
3858
|
+
... atlas_files=atlas_files,
|
|
3859
|
+
... metric_files=metric_files,
|
|
3860
|
+
... html_output="quality_report.html",
|
|
3861
|
+
... )
|
|
3862
|
+
"""
|
|
3863
|
+
# Initialize XVFB (X Virtual Framebuffer) if requested for headless environments
|
|
3864
|
+
vdisplay = None
|
|
3865
|
+
if os.environ.get("XVFB", "").lower() in ("1", "true", "yes"):
|
|
3866
|
+
logger.info("Initializing XVFB for headless rendering")
|
|
3867
|
+
vdisplay = Xvfb()
|
|
3868
|
+
vdisplay.start()
|
|
3869
|
+
|
|
3870
|
+
# Get output directory
|
|
3871
|
+
if output_dir is None:
|
|
3872
|
+
if self._output_directory is None:
|
|
3873
|
+
raise InvalidInputError(
|
|
3874
|
+
"No output directory provided. Set it via constructor or "
|
|
3875
|
+
"set_output_directory() method, or pass it as an argument.",
|
|
3876
|
+
)
|
|
3877
|
+
output_dir = self._output_directory
|
|
3878
|
+
else:
|
|
3879
|
+
output_dir = Path(output_dir)
|
|
3880
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
3881
|
+
|
|
3882
|
+
# Handle reference image(s) - can be single path or dict mapping subject_id -> path
|
|
3883
|
+
if ref_img is None:
|
|
3884
|
+
if self._reference_image is None:
|
|
3885
|
+
raise InvalidInputError(
|
|
3886
|
+
"No reference image provided. Set it via constructor or "
|
|
3887
|
+
"set_reference_image() method, or pass it as an argument.",
|
|
3888
|
+
)
|
|
3889
|
+
# Use instance reference image for all subjects
|
|
3890
|
+
subject_ref_imgs: dict[str, Path] = dict.fromkeys(subjects_original_space.keys(), self._reference_image)
|
|
3891
|
+
elif isinstance(ref_img, dict):
|
|
3892
|
+
# Dictionary mapping subject IDs to reference images
|
|
3893
|
+
subject_ref_imgs = {subject_id: Path(path) for subject_id, path in ref_img.items()}
|
|
3894
|
+
# Validate all subjects have reference images
|
|
3895
|
+
missing = set(subjects_original_space.keys()) - set(subject_ref_imgs.keys())
|
|
3896
|
+
if missing:
|
|
3897
|
+
raise InvalidInputError(
|
|
3898
|
+
f"Missing reference images for subjects: {', '.join(sorted(missing))}",
|
|
3899
|
+
)
|
|
3900
|
+
else:
|
|
3901
|
+
# Single reference image for all subjects
|
|
3902
|
+
single_ref_img = Path(ref_img)
|
|
3903
|
+
subject_ref_imgs = dict.fromkeys(subjects_original_space.keys(), single_ref_img)
|
|
3904
|
+
|
|
3905
|
+
# Set default skip_checks
|
|
3906
|
+
if skip_checks is None:
|
|
3907
|
+
skip_checks = []
|
|
3908
|
+
|
|
3909
|
+
# Determine number of jobs to use
|
|
3910
|
+
if n_jobs is None:
|
|
3911
|
+
n_jobs = self.n_jobs
|
|
3912
|
+
elif n_jobs == -1:
|
|
3913
|
+
# Use optimal n_jobs considering SLURM and OpenMP settings
|
|
3914
|
+
base_n_jobs = _get_optimal_n_jobs()
|
|
3915
|
+
# Further reduce if memory is limited
|
|
3916
|
+
n_jobs = _get_n_jobs_with_memory_limit(
|
|
3917
|
+
base_n_jobs,
|
|
3918
|
+
estimated_memory_per_job_mb=2000.0, # ~2GB per worker
|
|
3919
|
+
safety_margin=0.2,
|
|
3920
|
+
)
|
|
3921
|
+
else:
|
|
3922
|
+
n_jobs = max(1, n_jobs)
|
|
3923
|
+
# Still check memory even if n_jobs is explicitly set
|
|
3924
|
+
if psutil is not None:
|
|
3925
|
+
n_jobs = _get_n_jobs_with_memory_limit(
|
|
3926
|
+
n_jobs,
|
|
3927
|
+
estimated_memory_per_job_mb=2000.0,
|
|
3928
|
+
safety_margin=0.2,
|
|
3929
|
+
)
|
|
3930
|
+
|
|
3931
|
+
# Log the final n_jobs value for debugging
|
|
3932
|
+
logger.debug("Using n_jobs=%d for parallel processing", n_jobs)
|
|
3933
|
+
|
|
3934
|
+
# Initialize results dictionary
|
|
3935
|
+
results: dict[str, dict[str, dict[str, str | Path]]] = {}
|
|
3936
|
+
|
|
3937
|
+
# Prepare all tasks (subject_id, tract_name, tract_file combinations)
|
|
3938
|
+
tasks: list[tuple[str, str, str | Path, Path, Path]] = []
|
|
3939
|
+
for subject_id, tracts in subjects_original_space.items():
|
|
3940
|
+
subject_ref_img = subject_ref_imgs[subject_id]
|
|
3941
|
+
subject_output_dir = output_dir / subject_id
|
|
3942
|
+
subject_output_dir.mkdir(parents=True, exist_ok=True)
|
|
3943
|
+
|
|
3944
|
+
for tract_name, tract_file in tracts.items():
|
|
3945
|
+
tract_path = Path(tract_file)
|
|
3946
|
+
if not tract_path.exists():
|
|
3947
|
+
raise FileNotFoundError(f"Tract file not found: {tract_file}")
|
|
3948
|
+
|
|
3949
|
+
tract_output_dir = subject_output_dir / tract_name
|
|
3950
|
+
tract_output_dir.mkdir(parents=True, exist_ok=True)
|
|
3951
|
+
|
|
3952
|
+
tasks.append((subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir))
|
|
3953
|
+
|
|
3954
|
+
# Prepare visualizer parameters for worker processes
|
|
3955
|
+
visualizer_params = {
|
|
3956
|
+
"gif_size": self.gif_size,
|
|
3957
|
+
"gif_duration": self.gif_duration,
|
|
3958
|
+
"gif_palette_size": self.gif_palette_size,
|
|
3959
|
+
"gif_frames": self.gif_frames,
|
|
3960
|
+
"min_streamline_length": self.min_streamline_length,
|
|
3961
|
+
"cci_threshold": self.cci_threshold,
|
|
3962
|
+
"afq_resample_points": self.afq_resample_points,
|
|
3963
|
+
"n_jobs": 1, # Workers don't need parallelization
|
|
3964
|
+
}
|
|
3965
|
+
|
|
3966
|
+
# Process tasks in parallel or sequentially
|
|
3967
|
+
if n_jobs > 1 and len(tasks) > 1:
|
|
3968
|
+
logger.info("Processing %d tracts using %d workers", len(tasks), n_jobs)
|
|
3969
|
+
try:
|
|
3970
|
+
with ProcessPoolExecutor(max_workers=n_jobs) as executor:
|
|
3971
|
+
futures = {
|
|
3972
|
+
executor.submit(
|
|
3973
|
+
_process_tract_worker,
|
|
3974
|
+
subject_id=subject_id,
|
|
3975
|
+
tract_name=tract_name,
|
|
3976
|
+
tract_file=tract_file,
|
|
3977
|
+
subject_ref_img=subject_ref_img,
|
|
3978
|
+
tract_output_dir=tract_output_dir,
|
|
3979
|
+
subjects_mni_space=subjects_mni_space,
|
|
3980
|
+
atlas_files=atlas_files,
|
|
3981
|
+
metric_files=metric_files,
|
|
3982
|
+
atlas_ref_img=atlas_ref_img,
|
|
3983
|
+
flip_lr=flip_lr,
|
|
3984
|
+
skip_checks=skip_checks,
|
|
3985
|
+
visualizer_params=visualizer_params,
|
|
3986
|
+
**kwargs,
|
|
3987
|
+
): (subject_id, tract_name)
|
|
3988
|
+
for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in tasks
|
|
3989
|
+
}
|
|
3990
|
+
|
|
3991
|
+
# Collect results as they complete
|
|
3992
|
+
broken_pool_detected = False
|
|
3993
|
+
completed_count = 0
|
|
3994
|
+
try:
|
|
3995
|
+
for future in as_completed(futures):
|
|
3996
|
+
subject_id, tract_name = futures[future]
|
|
3997
|
+
try:
|
|
3998
|
+
# Get result with timeout - this is where BrokenProcessPool is raised
|
|
3999
|
+
result_subject_id, result_tract_name, tract_results = future.result(timeout=None)
|
|
4000
|
+
if result_subject_id not in results:
|
|
4001
|
+
results[result_subject_id] = {}
|
|
4002
|
+
results[result_subject_id][result_tract_name] = tract_results
|
|
4003
|
+
# Clean up the future reference
|
|
4004
|
+
del tract_results
|
|
4005
|
+
except RuntimeError as e:
|
|
4006
|
+
# RuntimeError catches BrokenProcessPool (which is a subclass of RuntimeError)
|
|
4007
|
+
# Check if this is a BrokenProcessPool by checking the error message and type name
|
|
4008
|
+
error_type = type(e).__name__
|
|
4009
|
+
error_msg = str(e)
|
|
4010
|
+
if (
|
|
4011
|
+
"BrokenProcessPool" in error_type
|
|
4012
|
+
or "BrokenProcessPool" in error_msg
|
|
4013
|
+
or "process pool" in error_msg.lower()
|
|
4014
|
+
or "was terminated abruptly" in error_msg
|
|
4015
|
+
):
|
|
4016
|
+
broken_pool_detected = True
|
|
4017
|
+
logger.exception(
|
|
4018
|
+
"BrokenProcessPool detected for %s/%s. Worker process may have crashed",
|
|
4019
|
+
subject_id,
|
|
4020
|
+
tract_name,
|
|
4021
|
+
)
|
|
4022
|
+
# Mark this task as failed
|
|
4023
|
+
if subject_id not in results:
|
|
4024
|
+
results[subject_id] = {}
|
|
4025
|
+
if tract_name not in results[subject_id]:
|
|
4026
|
+
results[subject_id][tract_name] = {}
|
|
4027
|
+
# Break out of the loop immediately - pool is broken, can't process more
|
|
4028
|
+
break
|
|
4029
|
+
logger.exception("RuntimeError processing %s/%s", subject_id, tract_name)
|
|
4030
|
+
if subject_id not in results:
|
|
4031
|
+
results[subject_id] = {}
|
|
4032
|
+
if tract_name not in results[subject_id]:
|
|
4033
|
+
results[subject_id][tract_name] = {}
|
|
4034
|
+
except (OSError, ValueError) as e:
|
|
4035
|
+
# OSError catches system-level errors
|
|
4036
|
+
# ValueError catches other processing errors
|
|
4037
|
+
logger.exception("Error processing %s/%s: %s", subject_id, tract_name, type(e).__name__)
|
|
4038
|
+
if subject_id not in results:
|
|
4039
|
+
results[subject_id] = {}
|
|
4040
|
+
if tract_name not in results[subject_id]:
|
|
4041
|
+
results[subject_id][tract_name] = {}
|
|
4042
|
+
except Exception as e:
|
|
4043
|
+
# Catch any other unexpected exceptions, including BrokenProcessPool
|
|
4044
|
+
error_type = type(e).__name__
|
|
4045
|
+
error_msg = str(e)
|
|
4046
|
+
if (
|
|
4047
|
+
"BrokenProcessPool" in error_type
|
|
4048
|
+
or "BrokenProcessPool" in error_msg
|
|
4049
|
+
or "process pool" in error_msg.lower()
|
|
4050
|
+
or "was terminated abruptly" in error_msg
|
|
4051
|
+
):
|
|
4052
|
+
broken_pool_detected = True
|
|
4053
|
+
logger.exception(
|
|
4054
|
+
"BrokenProcessPool detected for %s/%s (caught as Exception). Worker process may have crashed",
|
|
4055
|
+
subject_id,
|
|
4056
|
+
tract_name,
|
|
4057
|
+
)
|
|
4058
|
+
# Mark this task as failed
|
|
4059
|
+
if subject_id not in results:
|
|
4060
|
+
results[subject_id] = {}
|
|
4061
|
+
if tract_name not in results[subject_id]:
|
|
4062
|
+
results[subject_id][tract_name] = {}
|
|
4063
|
+
# Break out of the loop immediately - pool is broken, can't process more
|
|
4064
|
+
break
|
|
4065
|
+
logger.exception(
|
|
4066
|
+
"Unexpected error processing %s/%s: %s",
|
|
4067
|
+
subject_id,
|
|
4068
|
+
tract_name,
|
|
4069
|
+
type(e).__name__,
|
|
4070
|
+
)
|
|
4071
|
+
if subject_id not in results:
|
|
4072
|
+
results[subject_id] = {}
|
|
4073
|
+
if tract_name not in results[subject_id]:
|
|
4074
|
+
results[subject_id][tract_name] = {}
|
|
4075
|
+
finally:
|
|
4076
|
+
# Increment counter regardless of success/failure
|
|
4077
|
+
completed_count += 1
|
|
4078
|
+
|
|
4079
|
+
# Clean up future reference and remove from futures dict
|
|
4080
|
+
# This is critical to prevent memory accumulation
|
|
4081
|
+
futures.pop(future, None)
|
|
4082
|
+
del future
|
|
4083
|
+
|
|
4084
|
+
# Periodic garbage collection every 10 completed tasks
|
|
4085
|
+
# This helps prevent memory buildup during long-running jobs
|
|
4086
|
+
if completed_count > 0 and completed_count % 10 == 0:
|
|
4087
|
+
gc.collect()
|
|
4088
|
+
except RuntimeError as e:
|
|
4089
|
+
# Catch BrokenProcessPool that might break out of the loop
|
|
4090
|
+
error_type = type(e).__name__
|
|
4091
|
+
error_msg = str(e)
|
|
4092
|
+
if (
|
|
4093
|
+
"BrokenProcessPool" in error_type
|
|
4094
|
+
or "BrokenProcessPool" in error_msg
|
|
4095
|
+
or "process pool" in error_msg.lower()
|
|
4096
|
+
or "was terminated abruptly" in error_msg
|
|
4097
|
+
):
|
|
4098
|
+
broken_pool_detected = True
|
|
4099
|
+
logger.exception(
|
|
4100
|
+
"BrokenProcessPool detected during result collection. Falling back to sequential processing",
|
|
4101
|
+
)
|
|
4102
|
+
else:
|
|
4103
|
+
# Re-raise if it's a different RuntimeError
|
|
4104
|
+
raise
|
|
4105
|
+
except Exception as e:
|
|
4106
|
+
# Catch any other exceptions that might break the loop
|
|
4107
|
+
error_type = type(e).__name__
|
|
4108
|
+
error_msg = str(e)
|
|
4109
|
+
if (
|
|
4110
|
+
"BrokenProcessPool" in error_type
|
|
4111
|
+
or "BrokenProcessPool" in error_msg
|
|
4112
|
+
or "process pool" in error_msg.lower()
|
|
4113
|
+
or "was terminated abruptly" in error_msg
|
|
4114
|
+
):
|
|
4115
|
+
broken_pool_detected = True
|
|
4116
|
+
logger.exception(
|
|
4117
|
+
"BrokenProcessPool detected during result collection (caught as Exception). Falling back to sequential processing",
|
|
4118
|
+
)
|
|
4119
|
+
else:
|
|
4120
|
+
# Log but don't re-raise - try to continue with what we have
|
|
4121
|
+
logger.exception("Unexpected error during result collection: %s", type(e).__name__)
|
|
4122
|
+
|
|
4123
|
+
# If we detected a broken process pool, break out and fall back to sequential
|
|
4124
|
+
if broken_pool_detected:
|
|
4125
|
+
# Raise error to trigger fallback to sequential processing
|
|
4126
|
+
# The exception will be caught by the outer try-except block
|
|
4127
|
+
raise RuntimeError("BrokenProcessPool detected, falling back to sequential processing") # noqa: TRY301
|
|
4128
|
+
|
|
4129
|
+
# Clear futures dictionary to free memory
|
|
4130
|
+
futures.clear()
|
|
4131
|
+
|
|
4132
|
+
# Force garbage collection after all parallel tasks complete
|
|
4133
|
+
# Run multiple times to handle circular references
|
|
4134
|
+
for _ in range(3):
|
|
4135
|
+
gc.collect()
|
|
4136
|
+
except RuntimeError as e:
|
|
4137
|
+
# RuntimeError catches BrokenProcessPool (which is a subclass of RuntimeError)
|
|
4138
|
+
# and other runtime errors that might occur with process pools
|
|
4139
|
+
error_msg = str(e)
|
|
4140
|
+
if "BrokenProcessPool" in error_msg or "falling back" in error_msg.lower():
|
|
4141
|
+
logger.warning(
|
|
4142
|
+
"Process pool error occurred (worker process may have crashed with core dump). "
|
|
4143
|
+
"Falling back to sequential processing. "
|
|
4144
|
+
"If this persists, try: (1) reducing n_jobs, (2) increasing memory allocation, "
|
|
4145
|
+
"or (3) using n_jobs=1 to force sequential processing.",
|
|
4146
|
+
)
|
|
4147
|
+
else:
|
|
4148
|
+
logger.exception("RuntimeError in process pool. Falling back to sequential processing")
|
|
4149
|
+
|
|
4150
|
+
# Fall back to sequential processing if process pool fails
|
|
4151
|
+
# Only process tasks that haven't been completed yet
|
|
4152
|
+
remaining_tasks = [
|
|
4153
|
+
(s_id, t_name, t_file, s_ref, t_out)
|
|
4154
|
+
for s_id, t_name, t_file, s_ref, t_out in tasks
|
|
4155
|
+
if s_id not in results or t_name not in results.get(s_id, {})
|
|
4156
|
+
]
|
|
4157
|
+
|
|
4158
|
+
if remaining_tasks:
|
|
4159
|
+
logger.info("Processing %d remaining tracts sequentially (fallback)", len(remaining_tasks))
|
|
4160
|
+
for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in remaining_tasks:
|
|
4161
|
+
if subject_id not in results:
|
|
4162
|
+
results[subject_id] = {}
|
|
4163
|
+
# Skip if already processed
|
|
4164
|
+
if tract_name in results[subject_id]:
|
|
4165
|
+
continue
|
|
4166
|
+
try:
|
|
4167
|
+
tract_result = self._process_single_tract(
|
|
4168
|
+
subject_id=subject_id,
|
|
4169
|
+
tract_name=tract_name,
|
|
4170
|
+
tract_file=tract_file,
|
|
4171
|
+
subject_ref_img=subject_ref_img,
|
|
4172
|
+
tract_output_dir=tract_output_dir,
|
|
4173
|
+
subjects_mni_space=subjects_mni_space,
|
|
4174
|
+
atlas_files=atlas_files,
|
|
4175
|
+
metric_files=metric_files,
|
|
4176
|
+
atlas_ref_img=atlas_ref_img,
|
|
4177
|
+
flip_lr=flip_lr,
|
|
4178
|
+
skip_checks=skip_checks,
|
|
4179
|
+
**kwargs,
|
|
4180
|
+
)
|
|
4181
|
+
results[subject_id][tract_name] = tract_result
|
|
4182
|
+
# Clean up result reference
|
|
4183
|
+
del tract_result
|
|
4184
|
+
except (OSError, ValueError, RuntimeError, MemoryError) as process_error:
|
|
4185
|
+
logger.exception(
|
|
4186
|
+
"Error processing %s/%s in fallback mode (%s)",
|
|
4187
|
+
subject_id,
|
|
4188
|
+
tract_name,
|
|
4189
|
+
type(process_error).__name__,
|
|
4190
|
+
)
|
|
4191
|
+
results[subject_id][tract_name] = {}
|
|
4192
|
+
except Exception as process_error:
|
|
4193
|
+
logger.exception(
|
|
4194
|
+
"Unexpected error processing %s/%s in fallback mode (%s)",
|
|
4195
|
+
subject_id,
|
|
4196
|
+
tract_name,
|
|
4197
|
+
type(process_error).__name__,
|
|
4198
|
+
)
|
|
4199
|
+
results[subject_id][tract_name] = {}
|
|
4200
|
+
finally:
|
|
4201
|
+
# Force garbage collection after each tract in fallback mode
|
|
4202
|
+
# Run multiple times to handle circular references in VTK objects
|
|
4203
|
+
for _ in range(2):
|
|
4204
|
+
gc.collect()
|
|
4205
|
+
else:
|
|
4206
|
+
logger.info("All tasks already completed, skipping fallback processing")
|
|
4207
|
+
else:
|
|
4208
|
+
# Sequential processing
|
|
4209
|
+
logger.info("Processing %d tracts sequentially", len(tasks))
|
|
4210
|
+
for subject_id, tract_name, tract_file, subject_ref_img, tract_output_dir in tasks:
|
|
4211
|
+
if subject_id not in results:
|
|
4212
|
+
results[subject_id] = {}
|
|
4213
|
+
try:
|
|
4214
|
+
tract_result = self._process_single_tract(
|
|
4215
|
+
subject_id=subject_id,
|
|
4216
|
+
tract_name=tract_name,
|
|
4217
|
+
tract_file=tract_file,
|
|
4218
|
+
subject_ref_img=subject_ref_img,
|
|
4219
|
+
tract_output_dir=tract_output_dir,
|
|
4220
|
+
subjects_mni_space=subjects_mni_space,
|
|
4221
|
+
atlas_files=atlas_files,
|
|
4222
|
+
metric_files=metric_files,
|
|
4223
|
+
atlas_ref_img=atlas_ref_img,
|
|
4224
|
+
flip_lr=flip_lr,
|
|
4225
|
+
skip_checks=skip_checks,
|
|
4226
|
+
**kwargs,
|
|
4227
|
+
)
|
|
4228
|
+
results[subject_id][tract_name] = tract_result
|
|
4229
|
+
# Clean up result reference
|
|
4230
|
+
del tract_result
|
|
4231
|
+
finally:
|
|
4232
|
+
# Clean up after each tract in sequential processing
|
|
4233
|
+
# Run multiple times to handle circular references in VTK objects
|
|
4234
|
+
for _ in range(2):
|
|
4235
|
+
gc.collect()
|
|
4236
|
+
|
|
4237
|
+
# Generate HTML report
|
|
4238
|
+
html_output = output_dir / "quality_check_report.html" if html_output is None else Path(html_output)
|
|
4239
|
+
|
|
4240
|
+
# Convert Path objects to strings for HTML function
|
|
4241
|
+
results_for_html: dict[str, dict[str, dict[str, str]]] = {}
|
|
4242
|
+
for subject_id, subject_tracts in results.items():
|
|
4243
|
+
results_for_html[subject_id] = {}
|
|
4244
|
+
if isinstance(subject_tracts, dict):
|
|
4245
|
+
for tract_name, media_dict in subject_tracts.items():
|
|
4246
|
+
results_for_html[subject_id][tract_name] = {}
|
|
4247
|
+
if isinstance(media_dict, dict):
|
|
4248
|
+
for media_type, file_path in media_dict.items():
|
|
4249
|
+
# Convert Path to string, or keep as string/number
|
|
4250
|
+
if isinstance(file_path, Path):
|
|
4251
|
+
results_for_html[subject_id][tract_name][media_type] = str(file_path)
|
|
4252
|
+
else:
|
|
4253
|
+
results_for_html[subject_id][tract_name][media_type] = str(file_path)
|
|
4254
|
+
|
|
4255
|
+
try:
|
|
4256
|
+
create_quality_check_html(
|
|
4257
|
+
results_for_html,
|
|
4258
|
+
str(html_output),
|
|
4259
|
+
title="Tractography Quality Check Report",
|
|
4260
|
+
)
|
|
4261
|
+
logger.info("Quality check report generated: %s", html_output)
|
|
4262
|
+
except (OSError, ValueError, RuntimeError) as e:
|
|
4263
|
+
logger.warning("Failed to generate HTML report: %s", e)
|
|
4264
|
+
finally:
|
|
4265
|
+
# Clean up HTML conversion dictionary after use
|
|
4266
|
+
del results_for_html
|
|
4267
|
+
gc.collect()
|
|
4268
|
+
# Stop XVFB if it was started
|
|
4269
|
+
if vdisplay is not None:
|
|
4270
|
+
logger.info("Stopping XVFB")
|
|
4271
|
+
vdisplay.stop()
|
|
4272
|
+
return results
|