antspymm 1.5.5__tar.gz → 1.5.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {antspymm-1.5.5 → antspymm-1.5.6}/PKG-INFO +1 -1
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm/__init__.py +4 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm/mm.py +212 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm.egg-info/PKG-INFO +1 -1
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm.egg-info/SOURCES.txt +4 -1
- antspymm-1.5.6/docs/dti_distortion_correction_voxelwise_varying_bvectors_example_WIP.py +238 -0
- antspymm-1.5.6/docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py +281 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/pyproject.toml +1 -1
- antspymm-1.5.6/tests/voxelwise_bvec_dti_recon_test.py +118 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/MANIFEST.in +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/README.md +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm.egg-info/dependency_links.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm.egg-info/requires.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/antspymm.egg-info/top_level.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/adni_rsfmri_2_nrg_conversion.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/antspymm_annotated_output_tree.pages +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/antspymm_annotated_output_tree.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/antspymm_data_dictionary.csv +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/aslprep_perfusion_run_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/bids_2_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/bids_cohort_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/bind_mm_wide.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/blind_qc.Rmd +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/blind_qc.html +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/blind_qc.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/convert_adni_dti_to_nrg.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/deepnbm.jpg +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/deformation_gradient_reo.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/describe_mm_data.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dipy_dti_recon.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dti_recon.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dti_reg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dwi_rebasing.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dwi_run.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/dwi_run_ptbp_scrub.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ex_rsfmri_run_minimal_ptbp.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ex_sr.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/example_antspymm_output.csv +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/example_run_from_directory.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/flair_run_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/joint_dti_recon_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/make_dict_table.Rmd +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/make_dict_table.html +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/mm.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/mm_csv_ex_2.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/mm_csv_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/mm_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/nrg_cohort_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/parallel_study_aggregation_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/perfusion_ptbp.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/perfusion_run_nnl.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ppmi_step1_blind_qc.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ppmi_step2_outlierness.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ppmi_step3_mm_nrg_csv.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ppmi_step4_aggregate.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ptbp_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/roi_visualization.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/roi_visualization_ppmi.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/rsfmri_run_minimal_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/run_local_integration_scripts.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/run_mm_example.sh +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/template_overlays.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ukbb_rsfmri.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ukbb_to_nrg_processing.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/ukbb_to_nrg_processing2.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/docs/visualize_tractogram.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/setup.cfg +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/tests/test_loop.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/tests/test_nrg_validation.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.6}/tests/test_reference_run.py +0 -0
@@ -65,6 +65,8 @@ from .mm import mc_denoise
|
|
65
65
|
from .mm import mc_reg
|
66
66
|
from .mm import dti_reg
|
67
67
|
from .mm import timeseries_reg
|
68
|
+
from .mm import timeseries_transform
|
69
|
+
from .mm import copy_spatial_metadata_from_3d_to_4d
|
68
70
|
from .mm import concat_dewarp
|
69
71
|
from .mm import mc_resample_image_to_target
|
70
72
|
from .mm import trim_dti_mask
|
@@ -136,5 +138,7 @@ from .mm import segment_timeseries_by_bvalue
|
|
136
138
|
from .mm import shorten_pymm_names
|
137
139
|
from .mm import pet3d_summary
|
138
140
|
from .mm import deformation_gradient_optimized
|
141
|
+
from .mm import efficient_dwi_fit_voxelwise
|
142
|
+
from .mm import generate_voxelwise_bvecs
|
139
143
|
|
140
144
|
|
@@ -1794,6 +1794,89 @@ def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
|
|
1794
1794
|
mimg.append( temp )
|
1795
1795
|
return ants.list_to_ndimage( img_LR, mimg )
|
1796
1796
|
|
1797
|
+
def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
    """
    Stamp the spatial header of a 3D image onto the first three axes of a 4D image.

    The voxel data of ``timeseries_img`` is reused unchanged, as is the
    metadata of its fourth (time) axis: the 4th origin entry, the 4th
    spacing entry, and the last row/column of the 4x4 direction matrix.

    Parameters
    ----------
    spatial_img : ants.ANTsImage
        3D image whose origin, spacing and 3x3 direction are copied.
    timeseries_img : ants.ANTsImage
        4D image supplying the voxel data and the time-axis metadata.

    Returns
    -------
    ants.ANTsImage
        New 4D image built from ``timeseries_img.numpy()`` with the merged header.

    Raises
    ------
    ValueError
        If ``spatial_img`` is not 3D or ``timeseries_img`` is not 4D.
    """
    if spatial_img.dimension != 3:
        raise ValueError("spatial_img must be a 3D ANTsImage.")
    if timeseries_img.dimension != 4:
        raise ValueError("timeseries_img must be a 4D ANTsImage.")

    # Spatial part of the header, taken from the 3D image.
    spatial_origin = list(spatial_img.origin)
    spatial_spacing = list(spatial_img.spacing)
    spatial_direction = spatial_img.direction  # 3x3

    # Time-axis entries survive from the 4D image.
    time_spacing = list(timeseries_img.spacing)[3]
    time_origin = list(timeseries_img.origin)[3]

    # Start from the full 4x4 direction so the time row/column is kept,
    # then overwrite only the spatial 3x3 block.
    merged_direction = timeseries_img.direction.copy()
    merged_direction[:3, :3] = spatial_direction

    return ants.from_numpy(
        timeseries_img.numpy(),
        origin=spatial_origin + [time_origin],
        spacing=spatial_spacing + [time_spacing],
        direction=merged_direction,
    )
|
1840
|
+
|
1841
|
+
def timeseries_transform(transform, image, reference, interpolation='linear'):
    """
    Apply a spatial transform to each 3D volume in a 4D time series image.

    Every volume ``image[..., t]`` is resampled onto ``reference`` with
    ``ants.apply_ants_transform_to_image``. The results are stacked back
    along the time axis.

    Parameters
    ----------
    transform : ants.ANTsTransform
        In-memory transform object passed directly to
        ``ants.apply_ants_transform_to_image`` (for example the result of
        ``ants.transform_from_displacement_field``). This is a transform
        object, not a path to a transform file.
    image : ants.ANTsImage
        4D input image with shape (X, Y, Z, T).
    reference : ants.ANTsImage
        3D image defining the output voxel grid and spatial header.
    interpolation : str
        Interpolation method: 'linear', 'nearestNeighbor', etc.

    Returns
    -------
    ants.ANTsImage
        4D transformed image. The origin, spacing and direction of the first
        three axes come from ``reference``. The time-axis metadata comes from
        ``image``.

    Raises
    ------
    ValueError
        If ``image`` is not 4D or ``reference`` is not 3D.
    """
    if image.dimension != 4:
        raise ValueError("Input image must be 4D (X, Y, Z, T).")
    # The final header copy requires a 3D reference. Check it before
    # resampling every volume, so a bad argument fails immediately rather
    # than after all the expensive work.
    if reference.dimension != 3:
        raise ValueError("reference must be a 3D ANTsImage.")
    n_volumes = image.shape[3]
    transformed_volumes = []
    for t in range(n_volumes):
        vol = ants.slice_image(image, 3, t)
        transformed = ants.apply_ants_transform_to_image(
            transform=transform,
            image=vol,
            reference=reference,
            interpolation=interpolation,
        )
        transformed_volumes.append(transformed.numpy())
    # Stack along the time axis and convert back to an ANTsImage.
    transformed_array = np.stack(transformed_volumes, axis=-1)
    out_image = ants.from_numpy(transformed_array)
    # copy_image_info brings the time-axis header over from the input image.
    # The spatial part is then overwritten from the reference.
    out_image = ants.copy_image_info(image, out_image)
    out_image = copy_spatial_metadata_from_3d_to_4d(reference, out_image)
    return out_image
|
1797
1880
|
|
1798
1881
|
def timeseries_reg(
|
1799
1882
|
image,
|
@@ -3932,6 +4015,135 @@ def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
|
|
3932
4015
|
return full_fit, FA_img, MD_img, RGB_img
|
3933
4016
|
|
3934
4017
|
|
4018
|
+
def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
                                bvals_to_use=None, num_threads=1, verbose=True):
    """
    Voxel-wise DTI fitting with an individual set of b-vectors per voxel.

    For every voxel inside ``maskin``:

    - a dipy gradient table is built from the shared ``bvals`` and that
      voxel's own b-vectors;
    - a ``dipy.reconst.dti.TensorModel`` is fitted to the voxel signal;
    - FA, MD and color-FA are stored.

    This supports spatially varying gradient directions, e.g. b-vectors
    reoriented voxel by voxel.

    Parameters
    ----------
    imagein : ants.ANTsImage
        4D DWI image (X, Y, Z, N).
    maskin : ants.ANTsImage
        3D mask on the same (X, Y, Z) grid. Non-zero voxels are fitted.
    bvals : (N,) array-like
        b-values shared by all voxels. Lists and tuples are accepted.
    bvecs_5d : (X, Y, Z, N, 3) array-like
        Voxel-specific b-vectors.
    model_params : dict, optional
        Extra keyword arguments forwarded to ``dti.TensorModel``.
    bvals_to_use : list of int, optional
        Keep only the volumes whose b-value is in this list.
    num_threads : int
        Number of worker threads. Values <= 1 fit voxels sequentially.
    verbose : bool
        If True, print a status line, report per-voxel fit failures and
        show a progress bar.

    Returns
    -------
    FA_img : ants.ANTsImage
        Fractional anisotropy.
    MD_img : ants.ANTsImage
        Mean diffusivity.
    RGB_img : ants.ANTsImage
        3-channel color-FA image.

    The following voxels are left at 0 in all outputs:

    - voxels outside the mask;
    - voxels whose signal is all zero;
    - voxels whose gradient table or fit raised an exception.

    Output headers are copied from the first volume of ``imagein``.

    Raises
    ------
    ValueError
        If the mask, ``bvals`` or ``bvecs_5d`` do not match the image
        dimensions, or if ``bvals_to_use`` selects no volume.
    """
    import numpy as np
    import ants
    import dipy.reconst.dti as dti
    from dipy.core.gradients import gradient_table
    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    model_params = model_params or {}
    img = imagein.numpy()
    mask = maskin.numpy().astype(bool)
    X, Y, Z, N = img.shape

    # Coerce to arrays so the boolean selection below also works for
    # list/tuple input (indexing a list with a bool array raises TypeError).
    bvals = np.asarray(bvals).ravel()
    bvecs_5d = np.asarray(bvecs_5d)

    # Validate up front. A mismatch would otherwise surface as an IndexError
    # mid-fit, or as every voxel silently failing.
    if mask.shape != (X, Y, Z):
        raise ValueError(
            f"maskin shape {mask.shape} does not match image grid {(X, Y, Z)}."
        )
    if bvals.shape != (N,):
        raise ValueError(
            f"bvals has {bvals.size} entries but imagein has {N} volumes."
        )
    if bvecs_5d.shape != (X, Y, Z, N, 3):
        raise ValueError(
            f"bvecs_5d shape {bvecs_5d.shape} does not match expected {(X, Y, Z, N, 3)}."
        )

    if bvals_to_use is not None:
        sel = np.isin(bvals, bvals_to_use)
        if not np.any(sel):
            raise ValueError("bvals_to_use does not match any value in bvals.")
        img = img[..., sel]
        bvals = bvals[sel]
        bvecs_5d = bvecs_5d[..., sel, :]

    FA = np.zeros((X, Y, Z), dtype=np.float32)
    MD = np.zeros((X, Y, Z), dtype=np.float32)
    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)

    def fit_voxel(ix, iy, iz):
        # Each call writes only its own (ix, iy, iz) entries of FA/MD/RGB.
        if not mask[ix, iy, iz]:
            return
        sig = img[ix, iy, iz, :]
        if np.all(sig == 0):
            return
        bv = bvecs_5d[ix, iy, iz, :, :]
        try:
            # Building the gradient table is inside the try on purpose.
            # dipy can reject a voxel's b-vectors (e.g. non-unit vectors
            # after reorientation). That should fail only this voxel, not
            # abort the whole fit. bvecs is passed by keyword because newer
            # dipy releases deprecate the positional form.
            gtab = gradient_table(bvals, bvecs=bv)
            model = dti.TensorModel(gtab, **model_params)
            fit = model.fit(sig)
            evals = fit.evals
            evecs = fit.evecs
            FA[ix, iy, iz] = fractional_anisotropy(evals)
            MD[ix, iy, iz] = mean_diffusivity(evals)
            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], evecs)
        except Exception as e:
            if verbose:
                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")

    coords = np.argwhere(mask)
    if verbose:
        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")

    if num_threads > 1:
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords),
                      total=len(coords), disable=not verbose))
    else:
        for c in tqdm(coords, disable=not verbose):
            fit_voxel(*c)

    ref = ants.slice_image(imagein, axis=3, idx=0)
    return (
        ants.copy_image_info(ref, ants.from_numpy(FA)),
        ants.copy_image_info(ref, ants.from_numpy(MD)),
        ants.merge_channels([ants.copy_image_info(ref, ants.from_numpy(RGB[..., i])) for i in range(3)])
    )
|
4111
|
+
|
4112
|
+
|
4113
|
+
def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
    """
    Generate voxel-wise b-vectors from global b-vectors and a per-voxel matrix field.

    For every voxel (i, j, k) and direction n the output is
    ``R[i, j, k] @ global_bvecs[n]``. When ``transpose`` is True it is
    ``R[i, j, k].T @ global_bvecs[n]`` instead.

    Parameters
    ----------
    global_bvecs : array-like of shape (N, 3)
        Global diffusion gradient directions.
    voxel_rotations : ndarray of shape (X, Y, Z, 3, 3)
        3x3 matrix for each voxel, e.g. derived from the Jacobian of a
        deformation field.
    transpose : bool, optional
        If True, transpose each matrix before applying it to the b-vectors.

    Returns
    -------
    bvecs_5d : ndarray of shape (X, Y, Z, N, 3), dtype float32
        Voxel-specific b-vectors.

    Raises
    ------
    ValueError
        If ``voxel_rotations`` is not (X, Y, Z, 3, 3) or ``global_bvecs``
        is not (N, 3).
    """
    global_bvecs = np.asarray(global_bvecs)
    voxel_rotations = np.asarray(voxel_rotations)
    if voxel_rotations.ndim != 5 or voxel_rotations.shape[-2:] != (3, 3):
        raise ValueError(
            f"voxel_rotations must have shape (X, Y, Z, 3, 3), got {voxel_rotations.shape}."
        )
    if global_bvecs.ndim != 2 or global_bvecs.shape[1] != 3:
        raise ValueError(
            f"global_bvecs must have shape (N, 3), got {global_bvecs.shape}."
        )

    X, Y, Z = voxel_rotations.shape[:3]
    N = global_bvecs.shape[0]
    # Transpose every 3x3 matrix once, outside the loop.
    mats = np.swapaxes(voxel_rotations, -1, -2) if transpose else voxel_rotations

    bvecs_5d = np.zeros((X, Y, Z, N, 3), dtype=np.float32)
    # One stacked matmul per direction, (X,Y,Z,3,3) @ (3,) -> (X,Y,Z,3),
    # replaces the former X*Y*Z*N Python-level loop. Peak extra memory is a
    # single (X, Y, Z, 3) temporary.
    for n in range(N):
        bvecs_5d[..., n, :] = mats @ global_bvecs[n]
    return bvecs_5d
|
4146
|
+
|
3935
4147
|
def dipy_dti_recon(
|
3936
4148
|
image,
|
3937
4149
|
bvalsfn,
|
@@ -24,7 +24,9 @@ docs/deepnbm.jpg
|
|
24
24
|
docs/deformation_gradient_reo.py
|
25
25
|
docs/describe_mm_data.R
|
26
26
|
docs/dipy_dti_recon.py
|
27
|
+
docs/dti_distortion_correction_voxelwise_varying_bvectors_example_WIP.py
|
27
28
|
docs/dti_recon.py
|
29
|
+
docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py
|
28
30
|
docs/dti_reg.py
|
29
31
|
docs/dwi_rebasing.py
|
30
32
|
docs/dwi_run.py
|
@@ -62,4 +64,5 @@ docs/ukbb_to_nrg_processing2.py
|
|
62
64
|
docs/visualize_tractogram.py
|
63
65
|
tests/test_loop.py
|
64
66
|
tests/test_nrg_validation.py
|
65
|
-
tests/test_reference_run.py
|
67
|
+
tests/test_reference_run.py
|
68
|
+
tests/voxelwise_bvec_dti_recon_test.py
|
@@ -0,0 +1,238 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import ants
|
3
|
+
import os
|
4
|
+
from dipy.io.gradients import read_bvals_bvecs
|
5
|
+
from scipy.stats import pearsonr
|
6
|
+
import antspymm
|
7
|
+
nt = 2
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from scipy.stats import pearsonr
|
11
|
+
|
12
|
+
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a b-vector file and rotate every direction by ``rotmat``.

    Accepted layouts:

    - 3 rows x N columns: used as is, one direction per column;
    - anything whose first axis is not 3 (e.g. N rows x 3 columns):
      transposed into the 3 x N layout first.

    A file with exactly 3 rows is always read as 3 x N.

    Parameters
    ----------
    bvec_file : str or file-like
        Anything accepted by ``np.loadtxt``.
    rotmat : (3, 3) array-like
        Matrix applied as ``rotmat @ v`` to each b-vector ``v``.

    Returns
    -------
    ndarray of shape (N, 3)
        Rotated b-vectors, one per row.
    """
    raw = np.loadtxt(bvec_file)
    columns = raw if raw.shape[0] == 3 else raw.T
    return np.transpose(rotmat @ columns)
|
18
|
+
|
19
|
+
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean per-channel Pearson correlation between two RGB images.

    Both images are split with ``ants.split_channels``. For each of the first
    three channels, the voxel values where ``mask > 0`` are correlated, and
    the three correlations are averaged.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel (e.g. color-FA RGB) image with at least 3 channels.
    img2 : ants.ANTsImage
        Second multi-channel image on the same grid as ``img1``.
    mask : ants.ANTsImage
        Voxels where ``mask > 0`` are included.

    Returns
    -------
    float
        Mean Pearson correlation across the R, G and B channels. A channel
        that is constant inside the mask in either image contributes 0.0.

    Raises
    ------
    ValueError
        If ``img1.shape != img2.shape``.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    channels1 = ants.split_channels(img1)
    channels2 = ants.split_channels(img2)
    correlations = []
    for c in range(3):  # R, G, B
        x = extract_masked_values(channels1[c], mask)
        y = extract_masked_values(channels2[c], mask)
        if np.std(x) == 0 or np.std(y) == 0:
            # pearsonr is undefined for constant input; count it as 0.
            correlations.append(0.0)
        else:
            correlations.append(pearsonr(x, y)[0])
    return np.mean(correlations)
|
49
|
+
|
50
|
+
import numpy as np
|
51
|
+
import ants
|
52
|
+
|
53
|
+
def mean_rgb_mae(img1, img2, mask):
    """
    Compute the mean absolute error (MAE) between two RGB images.

    Both images are split with ``ants.split_channels``. For each of the first
    three channels, the MAE over voxels where ``mask > 0`` is computed, and
    the three values are averaged.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel (e.g. color-FA RGB) image with at least 3 channels.
    img2 : ants.ANTsImage
        Second multi-channel image on the same grid as ``img1``.
    mask : ants.ANTsImage
        Voxels where ``mask > 0`` are included in the error.

    Returns
    -------
    float
        Mean absolute error across the R, G and B channels.

    Raises
    ------
    ValueError
        If ``img1.shape != img2.shape``.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")

    channels1 = ants.split_channels(img1)
    channels2 = ants.split_channels(img2)

    mae_values = []
    for c in range(3):  # R, G, B
        # Promote to float64 before subtracting. With unsigned-integer pixel
        # types, ``x - y`` would wrap around instead of going negative.
        x = extract_masked_values(channels1[c], mask).astype(np.float64)
        y = extract_masked_values(channels2[c], mask).astype(np.float64)
        mae_values.append(np.mean(np.abs(x - y)))

    return np.mean(mae_values)
|
85
|
+
|
86
|
+
def extract_masked_values(image, mask):
    """
    Return the values of ``image`` at positions where ``mask`` is strictly positive.

    Both arguments only need a ``.numpy()`` method (e.g. ``ants.ANTsImage``).
    When the two arrays have the same shape, the result is 1D.
    """
    values = image.numpy()
    keep = mask.numpy() > 0
    return values[keep]
|
88
|
+
|
89
|
+
import numpy as np
|
90
|
+
|
91
|
+
def verify_unit_bvecs(bvecs, tol=1e-5):
    """
    Check which b-vectors have (approximately) unit Euclidean length.

    Parameters
    ----------
    bvecs : array-like, shape (N, 3)
        One diffusion direction per row.
    tol : float
        A row counts as unit length when ``|norm - 1| < tol`` (strict).

    Returns
    -------
    is_unit : np.ndarray of bool, shape (N,)
        True where the row's norm is within ``tol`` of 1.
    norms : np.ndarray, shape (N,)
        Euclidean norm of each row. Zero rows (e.g. b=0 entries) have norm 0
        and are therefore reported as not unit length.
    """
    rows = np.asarray(bvecs)
    row_lengths = np.linalg.norm(rows, axis=1)
    deviation = np.abs(row_lengths - 1)
    return deviation < tol, row_lengths
|
113
|
+
|
114
|
+
# def test_efficient_dwi_fit_voxelwise_distortion_correction():
# WIP example of DTI distortion correction with voxel-wise b-vectors:
#   1. Fit DTI on an LR and an RL DWI series.
#   2. Register the RL FA map to the LR FA map with SyNBold.
#   3. Warp the RL series onto the LR grid.
#   4. Refit the warped series with (a) voxel-wise reoriented b-vectors and
#      (b) the unmodified global b-vectors.
#   5. Compare the color-FA maps against the LR baseline.
# The ``"<name>" not in globals()`` guards let an interactive session re-run
# the script without repeating the expensive steps.
import tempfile

if True:
    print("simple test for distortion_correction consistency of efficient_dwi_fit_voxelwise")
    ex_path_mm = os.path.expanduser( "~/.antspymm/" )
    #### Load in data ####
    print("Load in subject data ...")
    lrid = "I1499279_Anon_20210819142214_5"
    rlid = "I1499337_Anon_20210819142214_6"

    print("📁 Loading subject LR data...")
    lrid = os.path.join(ex_path_mm, lrid)
    img_LR_in = ants.image_read(lrid + '.nii.gz')
    img_LR_in_avg = ants.get_average_of_timeseries(img_LR_in)
    mask = img_LR_in_avg.get_mask()
    bvals, bvecs = read_bvals_bvecs(lrid + '.bval', lrid + '.bvec')
    bvecs = np.asarray(bvecs)
    shape = img_LR_in.shape[:3]

    print("📁 Loading subject RL data...")
    rlid = os.path.join(ex_path_mm, rlid)
    img_RL_in = ants.image_read(rlid + '.nii.gz')
    bvalsRL, bvecsRL = read_bvals_bvecs(rlid + '.bval', rlid + '.bvec')
    bvecsRL = np.asarray(bvecsRL)
    # The RL acquisition is not guaranteed to share the LR voxel grid, so its
    # per-voxel b-vector array is sized from the RL image itself.
    shapeRL = img_RL_in.shape[:3]

    img_RL_in_avg = ants.get_average_of_timeseries(img_RL_in)
    maskRL = img_RL_in_avg.get_mask()

    print("🧠 Running baseline LR fit...")
    bvecs_5d_orig = np.broadcast_to(bvecs, shape + bvecs.shape).copy()
    if "FA_orig" not in globals():
        FA_orig, MD_orig, RGB_orig = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_LR_in,
            maskin=mask,
            bvals=bvals,
            bvecs_5d=bvecs_5d_orig,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    bvecs_5d_origRL = np.broadcast_to(bvecsRL, shapeRL + bvecsRL.shape).copy()
    if "FA_origRL" not in globals():
        FA_origRL, MD_origRL, RGB_origRL = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_RL_in,
            maskin=maskRL,
            bvals=bvalsRL,
            bvecs_5d=bvecs_5d_origRL,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    print("dist corr")
    if "mydef" not in globals():
        mytx = ants.registration(FA_orig, FA_origRL, 'SyNBold')
        # Compose into a private temporary directory rather than a fixed
        # /tmp prefix shared by every run and user. The composed field is
        # read into memory before the directory is deleted.
        with tempfile.TemporaryDirectory() as tempdir:
            mytx2 = ants.apply_transforms(
                FA_orig, FA_origRL, mytx['fwdtransforms'],
                interpolator='linear', imagetype=0,
                compose=os.path.join(tempdir, "comptx"),
            )
            print(mytx2)
            mydef = ants.image_read(mytx2)

    print("🔄 now with distortion correction...")
    mywarp = ants.transform_from_displacement_field(mydef)
    img_w = antspymm.timeseries_transform(mywarp, img_RL_in, reference=img_LR_in_avg)
    mask_w = ants.apply_ants_transform_to_image(mywarp, maskRL, reference=img_LR_in_avg, interpolation='nearestNeighbor')
    # img_w lives on the LR grid, so a global-bvec fit of it needs an
    # LR-grid-sized array.
    if shapeRL == shape:
        bvecs_5d_RL_on_LR = bvecs_5d_origRL
    else:
        bvecs_5d_RL_on_LR = np.broadcast_to(bvecsRL, shape + bvecsRL.shape).copy()

    print("🧠 Running warped fit...")
    if "FA_w" not in globals():
        # Per-voxel matrices derived from the deformation field reorient the
        # RL b-vectors.
        mydefgrad = antspymm.deformation_gradient_optimized(
            mydef, to_rotation=False, to_inverse_rotation=True)
        bvecsRLw = antspymm.generate_voxelwise_bvecs(bvecsRL, mydefgrad, transpose=False)
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=mask_w,
            bvals=bvalsRL,
            bvecs_5d=bvecsRLw,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    if "FA_w2" not in globals():
        FA_w2, MD_w2, RGB_w2 = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=mask_w,
            bvals=bvalsRL,
            bvecs_5d=bvecs_5d_RL_on_LR,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    fff = mean_rgb_correlation
    print("📊 Comparing results...")
    # mask + mask_w equals 2 only where both masks are set. Keep that
    # overlap, then erode it (iMath 'ME', radius 3).
    maskJoined = ants.threshold_image(mask + mask_w, 1.05, 2.0)
    maske = ants.iMath(maskJoined, 'ME', 3)

    rgb_corr = fff(RGB_orig, RGB_w, maske)
    print(f"✅ RGB correlation (original vs distortion corrected, voxelwise bvecs): {rgb_corr:.4f}")

    rgb_corrX = fff(RGB_orig, RGB_w2, maske)
    print(f"✅ RGB correlation (original vs distortion corrected, global bvecs): {rgb_corrX:.4f}")

    # Alternative: fit on the native RL grid, then warp the color-FA channels.
    RGB_origRLc = ants.split_channels(RGB_origRL)
    for c in range(3):
        RGB_origRLc[c] = ants.apply_ants_transform_to_image(
            mywarp, RGB_origRLc[c],
            reference=img_LR_in_avg, interpolation='linear'
        )

    rgb_corrY = fff(RGB_orig, ants.merge_channels(RGB_origRLc), maske)
    print(f"✅ RGB correlation (original vs warped RGB global recon): {rgb_corrY:.4f}")

    # TODO: enforce a threshold once a reliable value is known, e.g.
    # assert rgb_corr > 0.80, "RGB correlation too low"
    print("🎉 Comparison complete (no pass/fail threshold is enforced yet).")

# ants.image_write( FA_orig, '/tmp/xxx.nii.gz' )
# ants.image_write( FA_w, '/tmp/yyy.nii.gz' )
#
# Example usage:
# test_efficient_dwi_fit_voxelwise_distortion_correction()
|
@@ -0,0 +1,281 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import ants
|
3
|
+
import os
|
4
|
+
from dipy.io.gradients import read_bvals_bvecs
|
5
|
+
from scipy.stats import pearsonr
|
6
|
+
import antspymm
|
7
|
+
import matplotlib.pyplot as plt
|
8
|
+
nt = 2
|
9
|
+
# amount or rotation around x, y, z-axis
|
10
|
+
degrotx = 15
|
11
|
+
degroty = 15
|
12
|
+
degrotz = 15
|
13
|
+
|
14
|
+
|
15
|
+
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean per-channel Pearson correlation between two RGB images.

    Both images are split with ``ants.split_channels``. For each of the first
    three channels, the voxel values where ``mask > 0`` are correlated, and
    the three correlations are averaged.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel (e.g. color-FA RGB) image with at least 3 channels.
    img2 : ants.ANTsImage
        Second multi-channel image on the same grid as ``img1``.
    mask : ants.ANTsImage
        Voxels where ``mask > 0`` are included.

    Returns
    -------
    float
        Mean Pearson correlation across the R, G and B channels. A channel
        that is constant inside the mask in either image contributes 0.0.

    Raises
    ------
    ValueError
        If ``img1.shape != img2.shape``.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    channels1 = ants.split_channels(img1)
    channels2 = ants.split_channels(img2)
    correlations = []
    for c in range(3):  # R, G, B
        x = extract_masked_values(channels1[c], mask)
        y = extract_masked_values(channels2[c], mask)
        if np.std(x) == 0 or np.std(y) == 0:
            # pearsonr is undefined for constant input; count it as 0.
            correlations.append(0.0)
        else:
            correlations.append(pearsonr(x, y)[0])
    return np.mean(correlations)
|
45
|
+
|
46
|
+
|
47
|
+
def plot_correlation(x, y, xlabel="Image 1", ylabel="Image 2", title_prefix="Correlation", point_alpha=0.3):
    """
    Scatter-plot two equally shaped arrays against each other and report Pearson r.

    A dashed y = x reference line is drawn over the range of ``x``, and the
    figure is shown immediately with ``plt.show()``.

    Parameters
    ----------
    x, y : array-like
        Paired values. They must have identical shapes.
    xlabel, ylabel : str
        Axis labels.
    title_prefix : str
        Text placed before ``(r = ...)`` in the title.
    point_alpha : float
        Opacity of the scatter points.

    Returns
    -------
    float
        Pearson correlation coefficient r between ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in shape.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    if xs.shape != ys.shape:
        raise ValueError("Inputs must have the same shape.")

    r = pearsonr(xs, ys)[0]
    lo, hi = xs.min(), xs.max()

    plt.figure(figsize=(6, 6))
    plt.scatter(xs, ys, s=1, alpha=point_alpha, color='blue')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(f"{title_prefix} (r = {r:.3f})")
    plt.plot([lo, hi], [lo, hi], 'r--', linewidth=1)  # identity line y = x
    plt.grid(True)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()

    return r
|
91
|
+
|
92
|
+
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a b-vector file and rotate every direction by ``rotmat``.

    Accepted layouts:

    - 3 rows x N columns: used as is, one direction per column;
    - anything whose first axis is not 3 (e.g. N rows x 3 columns):
      transposed into the 3 x N layout first.

    A file with exactly 3 rows is always read as 3 x N.

    Parameters
    ----------
    bvec_file : str or file-like
        Anything accepted by ``np.loadtxt``.
    rotmat : (3, 3) array-like
        Matrix applied as ``rotmat @ v`` to each b-vector ``v``.

    Returns
    -------
    ndarray of shape (N, 3)
        Rotated b-vectors, one per row.
    """
    raw = np.loadtxt(bvec_file)
    columns = raw if raw.shape[0] == 3 else raw.T
    return np.transpose(rotmat @ columns)
|
98
|
+
|
99
|
+
def broadcast_bvecs_voxelwise(rotated_bvecs, shape):
    """
    Repeat one (N, 3) b-vector table at every voxel of a grid.

    Parameters
    ----------
    rotated_bvecs : ndarray of shape (N, 3)
        b-vectors shared by all voxels.
    shape : tuple
        Spatial grid shape, e.g. ``(X, Y, Z)``. It is concatenated with
        ``rotated_bvecs.shape``, so it must be a tuple.

    Returns
    -------
    ndarray of shape ``shape + rotated_bvecs.shape``
        Independent, writable copy (not a broadcast view).
    """
    full_shape = shape + rotated_bvecs.shape
    tiled_view = np.broadcast_to(rotated_bvecs, full_shape)
    # broadcast_to returns a read-only view, so materialize a real array.
    return np.array(tiled_view)
|
101
|
+
|
102
|
+
def extract_masked_values(image, mask):
    """
    Return the values of ``image`` at positions where ``mask`` is strictly positive.

    Both arguments only need a ``.numpy()`` method (e.g. ``ants.ANTsImage``).
    When the two arrays have the same shape, the result is 1D.
    """
    values = image.numpy()
    keep = mask.numpy() > 0
    return values[keep]
|
104
|
+
|
105
|
+
# def test_efficient_dwi_fit_voxelwise_rotation_consistency():
|
106
|
+
if True:
|
107
|
+
print("simple test for rotation consistency of efficient_dwi_fit_voxelwise")
|
108
|
+
ex_path = os.path.expanduser( "~/.antspyt1w/" )
|
109
|
+
ex_path_mm = os.path.expanduser( "~/.antspymm/" )
|
110
|
+
JHU_atlas = ants.image_read( ex_path + 'JHU-ICBM-FA-1mm.nii.gz' ) # Read in JHU atlas
|
111
|
+
JHU_labels = ants.image_read( ex_path + 'JHU-ICBM-labels-1mm.nii.gz' ) # Read in JHU labels
|
112
|
+
#### Load in data ####
|
113
|
+
print("Load in subject data ...")
|
114
|
+
lrid = ex_path_mm + "I1499279_Anon_20210819142214_5"
|
115
|
+
rlid = ex_path_mm + "I1499337_Anon_20210819142214_6"
|
116
|
+
t1id = ex_path_mm + "t1_rand.nii.gz"
|
117
|
+
# Load paths
|
118
|
+
print("📁 Loading subject data...")
|
119
|
+
lrid = os.path.join(ex_path_mm, "I1499279_Anon_20210819142214_5")
|
120
|
+
img_LR_in = ants.image_read(lrid + '.nii.gz')
|
121
|
+
img_LR_in_avg = ants.get_average_of_timeseries( img_LR_in )
|
122
|
+
mask = img_LR_in_avg.get_mask()
|
123
|
+
|
124
|
+
bvals, bvecs = read_bvals_bvecs(lrid + '.bval', lrid + '.bvec')
|
125
|
+
bvecs = np.asarray(bvecs)
|
126
|
+
shape = img_LR_in.shape[:3]
|
127
|
+
|
128
|
+
print("🧠 Running baseline (unrotated) fit...")
|
129
|
+
bvecs_5d_orig = np.broadcast_to(bvecs, shape + bvecs.shape).copy()
|
130
|
+
if not "FA_orig" in globals():
|
131
|
+
FA_orig, MD_orig, RGB_orig = antspymm.efficient_dwi_fit_voxelwise(
|
132
|
+
imagein=img_LR_in,
|
133
|
+
maskin=mask,
|
134
|
+
bvals=bvals,
|
135
|
+
bvecs_5d=bvecs_5d_orig,
|
136
|
+
model_params={},
|
137
|
+
bvals_to_use=None,
|
138
|
+
num_threads=nt,
|
139
|
+
verbose=False
|
140
|
+
)
|
141
|
+
|
142
|
+
print("🔄 Applying known rotation...")
|
143
|
+
# maxtrans = 10.0
|
144
|
+
# rotator = ants.contrib.RandomRotate3D((-maxtrans, maxtrans), reference=img_LR_in_avg)
|
145
|
+
rotator=ants.contrib.Rotate3D(rotation=(degrotx,degroty,degrotz), reference=img_LR_in_avg )
|
146
|
+
rotation = rotator.transform()
|
147
|
+
img_rotated_ref = ants.apply_ants_transform_to_image(rotation, img_LR_in_avg, reference=img_LR_in_avg)
|
148
|
+
img_rotated = antspymm.timeseries_transform(rotation, img_LR_in, reference=img_LR_in_avg)
|
149
|
+
mask_rotated = ants.apply_ants_transform_to_image(rotation, mask, reference=img_LR_in_avg, interpolation='nearestNeighbor')
|
150
|
+
# note: we apply the inverse rotation to the bvecs
|
151
|
+
# i.e. if we register A to B and get R then apply R_inv to bvecs
|
152
|
+
rotmat = ants.get_ants_transform_parameters(rotation.invert()).reshape((4, 3))[:3, :3]
|
153
|
+
bvecs_rotated = read_bvecs_rotated(lrid + '.bvec', rotmat)
|
154
|
+
bvecs_5d_rot = broadcast_bvecs_voxelwise(bvecs_rotated, shape)
|
155
|
+
|
156
|
+
print("🧠 Running rotated fit...")
|
157
|
+
if not "FA_rot" in globals():
|
158
|
+
FA_rot, MD_rot, RGB_rot = antspymm.efficient_dwi_fit_voxelwise(
|
159
|
+
imagein=img_rotated,
|
160
|
+
maskin=mask_rotated,
|
161
|
+
bvals=bvals,
|
162
|
+
bvecs_5d=bvecs_5d_rot,
|
163
|
+
model_params={},
|
164
|
+
bvals_to_use=None,
|
165
|
+
num_threads=nt,
|
166
|
+
verbose=False
|
167
|
+
)
|
168
|
+
|
169
|
+
FA_rot_back = ants.apply_ants_transform_to_image(
|
170
|
+
rotation.invert(),
|
171
|
+
FA_rot,
|
172
|
+
reference = img_LR_in_avg,
|
173
|
+
)
|
174
|
+
MD_rot_back = ants.apply_ants_transform_to_image(
|
175
|
+
rotation.invert(),
|
176
|
+
MD_rot,
|
177
|
+
reference = img_LR_in_avg,
|
178
|
+
)
|
179
|
+
print("📊 Comparing results...")
|
180
|
+
maske=ants.iMath(mask,'ME',2)
|
181
|
+
# smoothing simulates double interpolation
|
182
|
+
FA_origs = ants.smooth_image(FA_orig, 1.25)
|
183
|
+
MD_origs = ants.smooth_image(MD_orig, 1.25)
|
184
|
+
fa1 = extract_masked_values(FA_origs, maske)
|
185
|
+
fa2 = extract_masked_values(FA_rot_back, maske)
|
186
|
+
md1 = extract_masked_values(MD_origs, maske)
|
187
|
+
md2 = extract_masked_values(MD_rot_back, maske)
|
188
|
+
|
189
|
+
fa_corr, _ = pearsonr(fa1, fa2)
|
190
|
+
md_corr, _ = pearsonr(md1, md2)
|
191
|
+
|
192
|
+
# plot_correlation( fa1, fa2 )
|
193
|
+
|
194
|
+
print(f"✅ FA correlation (original vs rotated): {fa_corr:.4f}")
|
195
|
+
print(f"✅ MD correlation (original vs rotated): {md_corr:.4f}")
|
196
|
+
|
197
|
+
assert fa_corr > 0.80, "FA correlation too low"
|
198
|
+
assert md_corr > 0.80, "MD correlation too low"
|
199
|
+
|
200
|
+
print("🎉 Test passed: model is rotation-consistent with voxelwise bvecs.")
|
201
|
+
|
202
|
+
print("This shows that the simulation is effective.")
|
203
|
+
print("Now use the simulated data to test the distortion correction...")
|
204
|
+
# ants.image_write( RGB_orig, '/tmp/temp0rgb.nii.gz' )
|
205
|
+
# ants.image_write( RGB_rot, '/tmp/temp1rgb.nii.gz' )
|
206
|
+
|
207
|
+
|
208
|
+
# now we have to map the img_rotated and its bvec_rotated partner
|
209
|
+
# as we would with a generic distortion correction framework.
|
210
|
+
# first --- write the transform to a file
|
211
|
+
rotationinv = rotation.invert()
|
212
|
+
|
213
|
+
import tempfile

# Map img_rotated (and its bvecs_rotated) back to the original space as a
# generic distortion-correction pipeline would, then compare DTI fits that
# use the un-reoriented bvecs vs. voxelwise-reoriented bvecs.
with tempfile.TemporaryDirectory() as tempdir:
    rotmatfile = os.path.join(tempdir, "rotation.mat")
    compositefile = os.path.join(tempdir, "composite")

    # Write the inverse rotation so it can initialize the registration.
    ants.write_transform(rotationinv, rotmatfile)

    # Reuse a registration already present in the module namespace (e.g. from
    # an earlier interactive run); otherwise run SyN seeded with the rotation.
    if "reg" not in globals():
        reg = ants.registration(FA_orig, FA_rot, 'SyN', initial_transform=rotmatfile)

    # Compose the forward transforms into a single displacement-field file.
    comptx = ants.apply_transforms(
        img_LR_in_avg,
        img_rotated_ref,
        reg['fwdtransforms'],
        interpolator='linear',
        compose=compositefile,
        verbose=True,
    )
    mydef = ants.image_read(comptx)
    # Per-voxel inverse rotation derived from the displacement field; used to
    # reorient the bvecs voxelwise.
    mydefgrad = antspymm.deformation_gradient_optimized(
        mydef, to_rotation=False, to_inverse_rotation=True
    )
    bvecsRLw = antspymm.generate_voxelwise_bvecs(bvecs_rotated, mydefgrad, transpose=False)
    mywarp = ants.transform_from_displacement_field(mydef)
    img_rotated_avg = ants.get_average_of_timeseries(img_rotated)
    img_w = antspymm.timeseries_transform(mywarp, img_rotated, reference=img_rotated_avg)

    # Both masks depend only on img_w, so compute them once rather than
    # recomputing identical masks on every loop iteration.
    fitmask = ants.get_mask(ants.get_average_of_timeseries(img_w))
    compmask = ants.iMath(fitmask, 'ME', 2)

    RGB_origs = ants.smooth_image(RGB_orig, 1.0)

    # Optional: write RGB images to temporary files for inspection.
    # NOTE: everything in `tempdir` is deleted when this `with` block exits;
    # copy the files elsewhere if they must survive the script.
    rgb_orig_path = os.path.join(tempdir, "RGB_orig.nii.gz")
    ants.image_write(RGB_origs, rgb_orig_path)
    print(f"📝 Saved RGB_orig to: {rgb_orig_path}")

    correlations = []
    labels = ["NoBvecReo", "BvecReo"]

    for label, bv in zip(labels, [bvecs_5d_rot, bvecsRLw]):
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=fitmask,
            bvals=bvals,
            bvecs_5d=bv,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )
        rgb_warped_path = os.path.join(tempdir, f"RGB_warped_{label}.nii.gz")
        ants.image_write(RGB_w, rgb_warped_path)
        print(f"📝 Saved RGB_warped ({label}) to: {rgb_warped_path}")

        corr = mean_rgb_correlation(RGB_origs, RGB_w, compmask)
        print(f"{label:10s} correlation: {corr:.4f}")
        correlations.append(corr)

    # Compare the two correlations
    if len(correlations) == 2:
        print("\nComparison Result:")
        if correlations[1] > correlations[0]:
            print("✅ BvecReo gives higher correlation than NoBvecReo.")
        else:
            print("❌ NoBvecReo gives equal or higher correlation than BvecReo.")
|
@@ -0,0 +1,118 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from scipy.spatial.transform import Rotation as R
|
3
|
+
import antspymm
|
4
|
+
|
5
|
+
def test_generate_voxelwise_bvecs():
    """A spatially constant rotation field must rotate every voxel's bvecs identically.

    Builds a (X, Y, Z, 3, 3) field holding the same z-rotation at every voxel,
    and checks that ``antspymm.generate_voxelwise_bvecs`` returns an
    (X, Y, Z, N, 3) array in which each voxel equals ``(rot @ bvecs.T).T``.
    """
    # Define synthetic bvecs (N = 3 directions)
    bvecs = np.array([
        [1, 0, 0],  # x-axis
        [0, 1, 0],  # y-axis
        [0, 0, 1],  # z-axis
    ])

    # Define image shape (small synthetic 4D volume)
    X, Y, Z = 2, 2, 2
    N = bvecs.shape[0]

    # Define a known rotation (33 degrees around the z-axis)
    rot = R.from_euler('z', 33, degrees=True).as_matrix()  # shape (3, 3)

    # Voxelwise rotation field with the same rotation everywhere; tiling a
    # (3, 3) matrix with reps (X, Y, Z, 1, 1) already yields (X, Y, Z, 3, 3).
    voxel_rotations = np.tile(rot, (X, Y, Z, 1, 1))

    # Expected rotated bvecs: each row b becomes rot @ b
    expected = (rot @ bvecs.T).T  # (N, 3)

    # Call the function under test
    bvecs_5d = antspymm.generate_voxelwise_bvecs(bvecs, voxel_rotations)

    # Pin the documented layout; np.allclose alone would broadcast and could
    # accept a wrongly shaped result.
    assert bvecs_5d.shape == (X, Y, Z, N, 3), \
        f"Unexpected output shape {bvecs_5d.shape}, expected {(X, Y, Z, N, 3)}"

    # Check that all voxel outputs match expected result
    for idx in np.ndindex(X, Y, Z):
        actual = bvecs_5d[idx]
        assert np.allclose(actual, expected, atol=1e-6), \
            f"Mismatch at voxel {idx}: {actual} vs {expected}"

    print("✅ test_generate_voxelwise_bvecs passed!")
|
39
|
+
|
40
|
+
# Run the test only when this file is executed directly. pytest collects
# test_* functions itself, so an unguarded call would run the test a second
# time at import and turn a failure into a collection error.
if __name__ == "__main__":
    print("Running test for generate_voxelwise_bvecs...")
    test_generate_voxelwise_bvecs()
|
43
|
+
|
44
|
+
|
45
|
+
import numpy as np
|
46
|
+
import ants
|
47
|
+
from scipy.spatial.transform import Rotation as R
|
48
|
+
|
49
|
+
def generate_dummy_dwi_data(shape, n_volumes, seed=42):
    """
    Generate synthetic DWI-like 4D data and a brain mask.

    Parameters
    ----------
    shape : tuple of int
        Spatial shape (X, Y, Z) of the volume.
    n_volumes : int
        Number of diffusion volumes (length of the last axis).
    seed : int, optional
        Seed for a private legacy ``np.random.RandomState``. The default (42)
        reproduces exactly the values produced by
        ``np.random.seed(42); np.random.rand(...)``, but without modifying
        NumPy's global random state (which would leak into other tests).

    Returns
    -------
    dwi : np.ndarray of float32, shape ``(*shape, n_volumes)``
        Uniform random values in [0, 1).
    mask : np.ndarray of uint8, shape ``shape``
        All-ones mask.
    """
    rng = np.random.RandomState(seed)
    dwi = rng.rand(*shape, n_volumes).astype(np.float32)
    mask = np.ones(shape, dtype=np.uint8)
    return dwi, mask
|
57
|
+
|
58
|
+
def create_voxelwise_bvecs(shape, bvecs, rotation_matrix=None):
    """
    Create voxelwise (5D) bvecs array (X, Y, Z, N, 3), optionally rotated.

    Parameters
    ----------
    shape : sequence of int
        Spatial shape (X, Y, Z); a tuple or a list is accepted.
    bvecs : array_like, shape (N, 3)
        Gradient directions, one per row.
    rotation_matrix : array_like, shape (3, 3), optional
        If given, every direction ``b`` is replaced by ``rotation_matrix @ b``.

    Returns
    -------
    np.ndarray, shape ``(*shape, N, 3)``
        The same (optionally rotated) direction table at every voxel, as an
        independent writable array (not a broadcast view).
    """
    bvecs = np.asarray(bvecs)
    if rotation_matrix is not None:
        # Rotate each row vector: b -> R @ b
        bvecs = (np.asarray(rotation_matrix) @ bvecs.T).T

    # Broadcast to 5D; tuple() lets list-valued shapes concatenate with
    # bvecs.shape, and .copy() turns the read-only view into a real array.
    return np.broadcast_to(bvecs, tuple(shape) + bvecs.shape).copy()
|
70
|
+
|
71
|
+
def test_efficient_dwi_fit_voxelwise():
    """Smoke-test the voxelwise-bvec DTI fit on a tiny random 3x3x3 volume.

    Checks that FA is an ANTsImage with the spatial shape of the input and
    values in [0, 1], and that MD has the same spatial shape.
    """
    spatial_shape = (3, 3, 3)
    volume_count = 6

    # Random DWI signal plus an all-ones mask, wrapped as ANTs images.
    dwi_arr, mask_arr = generate_dummy_dwi_data(spatial_shape, volume_count)
    dwi_image = ants.from_numpy(dwi_arr)
    mask_image = ants.from_numpy(mask_arr)

    # One b=0 volume followed by five b=1000 directions.
    bvals = np.array([0, 1000, 1000, 1000, 1000, 1000])
    inv_sqrt2 = 1 / np.sqrt(2)
    bvecs = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [inv_sqrt2, inv_sqrt2, 0],
        [inv_sqrt2, 0, inv_sqrt2],
    ])

    # The same 45-degree z-rotation is applied to the bvecs at every voxel.
    z_rotation = R.from_euler('z', 45, degrees=True).as_matrix()
    voxel_bvecs = create_voxelwise_bvecs(spatial_shape, bvecs, z_rotation)

    fa, md, _ = antspymm.efficient_dwi_fit_voxelwise(
        imagein=dwi_image,
        maskin=mask_image,
        bvals=bvals,
        bvecs_5d=voxel_bvecs,
        model_params={},
        bvals_to_use=None,
        num_threads=1,
        verbose=False,
    )

    assert isinstance(fa, ants.ANTsImage), "FA_img should be an ANTsImage"
    assert fa.shape == spatial_shape, f"FA image has shape {fa.shape}, expected {spatial_shape}"
    fa_values = fa.numpy()
    assert np.all((fa_values >= 0) & (fa_values <= 1)), "FA values should be in [0, 1]"
    assert md.shape == spatial_shape, "MD image has incorrect shape"

    print("✅ test_efficient_dwi_fit_voxelwise passed!")
|
116
|
+
|
117
|
+
# Run the test only when this file is executed directly; pytest discovers
# test_* functions on its own, so an unguarded call would execute the test a
# second time during module import/collection.
if __name__ == "__main__":
    test_efficient_dwi_fit_voxelwise()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|