antspymm 1.5.5__tar.gz → 1.5.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {antspymm-1.5.5 → antspymm-1.5.7}/PKG-INFO +1 -1
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm/__init__.py +6 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm/mm.py +246 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/PKG-INFO +1 -1
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/SOURCES.txt +4 -1
- antspymm-1.5.7/docs/dti_distortion_correction_voxelwise_varying_bvectors_example.py +353 -0
- antspymm-1.5.7/docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py +282 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/pyproject.toml +1 -1
- antspymm-1.5.7/tests/voxelwise_bvec_dti_recon_test.py +118 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/MANIFEST.in +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/README.md +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/dependency_links.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/requires.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/top_level.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/adni_rsfmri_2_nrg_conversion.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_annotated_output_tree.pages +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_annotated_output_tree.txt +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_data_dictionary.csv +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/aslprep_perfusion_run_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/bids_2_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/bids_cohort_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/bind_mm_wide.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.Rmd +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.html +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/convert_adni_dti_to_nrg.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/deepnbm.jpg +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/deformation_gradient_reo.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/describe_mm_data.R +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dipy_dti_recon.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dti_recon.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dti_reg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_rebasing.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_run.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_run_ptbp_scrub.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ex_rsfmri_run_minimal_ptbp.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ex_sr.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/example_antspymm_output.csv +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/example_run_from_directory.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/flair_run_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/joint_dti_recon_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/make_dict_table.Rmd +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/make_dict_table.html +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_csv_ex_2.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_csv_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/nrg_cohort_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/parallel_study_aggregation_example.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/perfusion_ptbp.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/perfusion_run_nnl.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step1_blind_qc.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step2_outlierness.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step3_mm_nrg_csv.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step4_aggregate.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ptbp_nrg.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/roi_visualization.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/roi_visualization_ppmi.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/rsfmri_run_minimal_localint.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/run_local_integration_scripts.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/run_mm_example.sh +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/template_overlays.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_rsfmri.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_to_nrg_processing.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_to_nrg_processing2.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/docs/visualize_tractogram.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/setup.cfg +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_loop.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_nrg_validation.py +0 -0
- {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_reference_run.py +0 -0
@@ -65,6 +65,8 @@ from .mm import mc_denoise
|
|
65
65
|
from .mm import mc_reg
|
66
66
|
from .mm import dti_reg
|
67
67
|
from .mm import timeseries_reg
|
68
|
+
from .mm import timeseries_transform
|
69
|
+
from .mm import copy_spatial_metadata_from_3d_to_4d
|
68
70
|
from .mm import concat_dewarp
|
69
71
|
from .mm import mc_resample_image_to_target
|
70
72
|
from .mm import trim_dti_mask
|
@@ -136,5 +138,9 @@ from .mm import segment_timeseries_by_bvalue
|
|
136
138
|
from .mm import shorten_pymm_names
|
137
139
|
from .mm import pet3d_summary
|
138
140
|
from .mm import deformation_gradient_optimized
|
141
|
+
from .mm import efficient_dwi_fit_voxelwise
|
142
|
+
from .mm import generate_voxelwise_bvecs
|
143
|
+
from .mm import distortion_correct_bvecs
|
144
|
+
|
139
145
|
|
140
146
|
|
@@ -1794,6 +1794,89 @@ def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
|
|
1794
1794
|
mimg.append( temp )
|
1795
1795
|
return ants.list_to_ndimage( img_LR, mimg )
|
1796
1796
|
|
1797
|
+
def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
    """
    Give a 4D image the spatial header of a 3D image.

    The origin, spacing and direction of the first three axes are taken from
    ``spatial_img``. The fourth axis keeps the origin and spacing of
    ``timeseries_img``, and the 4x4 direction matrix keeps its last row and
    column; only its upper-left 3x3 block is replaced. Voxel values are copied
    unchanged.

    Parameters
    ----------
    spatial_img : ants.ANTsImage
        A 3D ANTsImage with the desired spatial metadata.
    timeseries_img : ants.ANTsImage
        A 4D ANTsImage to update.

    Returns
    -------
    ants.ANTsImage
        A new 4D ANTsImage with updated spatial metadata.

    Raises
    ------
    ValueError
        If ``spatial_img`` is not 3D or ``timeseries_img`` is not 4D.
    """
    if spatial_img.dimension != 3:
        raise ValueError("spatial_img must be a 3D ANTsImage.")
    if timeseries_img.dimension != 4:
        raise ValueError("timeseries_img must be a 4D ANTsImage.")

    # Start from the 4D header and overwrite only its spatial entries.
    origin_4d = list(timeseries_img.origin)
    spacing_4d = list(timeseries_img.spacing)
    origin_4d[:3] = list(spatial_img.origin)
    spacing_4d[:3] = list(spatial_img.spacing)

    direction_4d = timeseries_img.direction.copy()
    direction_4d[:3, :3] = spatial_img.direction  # last row/column untouched

    return ants.from_numpy(
        timeseries_img.numpy(),
        origin=origin_4d,
        spacing=spacing_4d,
        direction=direction_4d,
    )
|
1840
|
+
|
1841
|
+
def timeseries_transform(transform, image, reference, interpolation='linear'):
    """
    Apply a spatial transform to each 3D volume in a 4D time series image.

    Every volume ``image[..., t]`` is resampled onto the grid of ``reference``
    with ``ants.apply_ants_transform_to_image``, and the results are restacked
    along the 4th axis. The output's first three axes take their origin,
    spacing and direction from ``reference``. The time axis keeps the origin
    and spacing of ``image``.

    Parameters
    ----------
    transform : ants.ANTsTransform
        Transform object to apply, e.g. the result of
        ``ants.transform_from_displacement_field``. It is passed directly to
        ``ants.apply_ants_transform_to_image``; transform *file paths* are not
        handled here (``ants.apply_transforms`` is the function for those).
    image : ants.ANTsImage
        4D input image with shape (X, Y, Z, T).
    reference : ants.ANTsImage
        3D reference image defining the output grid.
    interpolation : str
        Interpolation method forwarded to ``ants.apply_ants_transform_to_image``,
        e.g. 'linear', 'nearestNeighbor'.

    Returns
    -------
    ants.ANTsImage
        4D transformed image on the reference grid with T volumes.

    Raises
    ------
    ValueError
        If ``image`` is not 4D, ``reference`` is not 3D, or ``image`` has no
        volumes along its 4th axis.
    """
    if image.dimension != 4:
        raise ValueError("Input image must be 4D (X, Y, Z, T).")
    # Checked up front: the header copy at the end needs a 3D reference, and
    # failing there would waste all the per-volume resampling work.
    if reference.dimension != 3:
        raise ValueError("reference must be a 3D ANTsImage.")
    n_volumes = image.shape[3]
    if n_volumes == 0:
        raise ValueError("Input image has no volumes along the 4th axis.")

    transformed_volumes = []
    for t in range(n_volumes):
        vol = ants.slice_image(image, 3, t)
        transformed = ants.apply_ants_transform_to_image(
            transform=transform,
            image=vol,
            reference=reference,
            interpolation=interpolation,
        )
        transformed_volumes.append(transformed.numpy())

    # Stack along the time axis and convert to an ANTsImage.
    out_image = ants.from_numpy(np.stack(transformed_volumes, axis=-1))
    # Take the input's 4D header first (keeps the time-axis origin/spacing),
    # then overwrite its spatial part with the reference's.
    out_image = ants.copy_image_info(image, out_image)
    return copy_spatial_metadata_from_3d_to_4d(reference, out_image)
|
1797
1880
|
|
1798
1881
|
def timeseries_reg(
|
1799
1882
|
image,
|
@@ -2055,6 +2138,40 @@ def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
|
|
2055
2138
|
bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
|
2056
2139
|
return bvecs
|
2057
2140
|
|
2141
|
+
|
2142
|
+
def distortion_correct_bvecs(bvecs, def_grad, A_img, A_ref):
    """
    Vectorized computation of voxel-wise distortion corrected b-vectors.

    For each voxel the combined matrix
    ``R_voxel = A_ref.T @ A_img @ def_grad[x, y, z]`` is applied to every
    b-vector (``R_voxel @ b``), and the result is rescaled to unit length.
    Zero-length vectors (e.g. b0 volumes) stay zero.

    Parameters
    ----------
    bvecs : array-like (N, 3)
        Gradient directions; converted to float.
    def_grad : ndarray (X, Y, Z, 3, 3)
        Per-voxel matrices derived from the deformation gradient. They are
        applied as-is, not transposed, so they must already be the rotation
        to apply. NOTE(review): confirm the orientation convention
        (rotation vs. inverse rotation) expected by callers.
    A_img : ndarray (3, 3)
        Direction matrix of the fixed image (target undistorted space).
    A_ref : ndarray (3, 3)
        Direction matrix of the moving image (being corrected).

    Returns
    -------
    bvecs_5d : ndarray (X, Y, Z, N, 3)
        Per-voxel unit b-vectors (zero where the rotated vector is ~0).

    Raises
    ------
    ValueError
        If ``bvecs`` is not (N, 3) or ``def_grad`` is not (X, Y, Z, 3, 3).
    """
    # Float cast so the in-place normalisation below also works for integer inputs.
    bvecs = np.asarray(bvecs, dtype=float)
    def_grad = np.asarray(def_grad)
    if bvecs.ndim != 2 or bvecs.shape[1] != 3:
        raise ValueError(f"bvecs must have shape (N, 3), got {bvecs.shape}.")
    if def_grad.ndim != 5 or def_grad.shape[3:] != (3, 3):
        raise ValueError(f"def_grad must have shape (X, Y, Z, 3, 3), got {def_grad.shape}.")

    X, Y, Z = def_grad.shape[:3]
    N = bvecs.shape[0]

    # Combined matrix per voxel: R_voxel = A_ref.T @ A_img @ def_grad
    A = np.asarray(A_ref).T @ np.asarray(A_img)
    R_voxel = np.einsum('ij,xyzjk->xyzik', A, def_grad).reshape(-1, 3, 3)

    # rotated[v, n] = R_voxel[v] @ bvecs[n]  (R_voxel itself, not its transpose)
    rotated = np.einsum('vij,nj->vni', R_voxel, bvecs)

    # Normalise to unit length; the clip keeps zero vectors at zero instead of NaN.
    norms = np.linalg.norm(rotated, axis=2, keepdims=True)
    rotated /= np.clip(norms, 1e-8, None)

    return rotated.reshape(X, Y, Z, N, 3)
|
2174
|
+
|
2058
2175
|
def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
|
2059
2176
|
"""
|
2060
2177
|
extract DTI data from a dipy tensormodel
|
@@ -3932,6 +4049,135 @@ def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
|
|
3932
4049
|
return full_fit, FA_img, MD_img, RGB_img
|
3933
4050
|
|
3934
4051
|
|
4052
|
+
def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
                                bvals_to_use=None, num_threads=1, verbose=True):
    """
    Voxel-wise diffusion tensor fitting with individual b-vectors per voxel.

    For every voxel inside ``maskin`` a gradient table is built from the
    shared ``bvals`` and that voxel's own b-vectors (``bvecs_5d[x, y, z]``).
    A dipy ``TensorModel`` is then fit to the voxel's signal. Voxels whose
    signal is all zero, or whose gradient table / fit raises, are left at 0
    in the outputs.

    Parameters
    ----------
    imagein : ants.ANTsImage
        4D DWI image (X, Y, Z, N).
    maskin : ants.ANTsImage
        3D mask on the same (X, Y, Z) grid; nonzero voxels are fit.
    bvals : (N,) array-like
        b-values shared by all voxels (lists are accepted).
    bvecs_5d : (X, Y, Z, N, 3) ndarray
        Voxel-specific b-vectors.
    model_params : dict, optional
        Extra keyword arguments forwarded to ``dipy.reconst.dti.TensorModel``.
    bvals_to_use : list[int], optional
        If given, only volumes whose b-value is in this list are used.
    num_threads : int
        Number of threads to use; values <= 1 fit serially.
    verbose : bool
        Whether to print status, per-voxel failures and progress bars.

    Returns
    -------
    FA_img : ants.ANTsImage
        Fractional anisotropy.
    MD_img : ants.ANTsImage
        Mean diffusivity.
    RGB_img : ants.ANTsImage
        3-channel color FA image.
    All outputs carry the spatial header of the first DWI volume.

    Raises
    ------
    ValueError
        If ``imagein`` is not 4D or the mask, b-values or b-vectors do not
        match its dimensions.
    """
    import numpy as np
    import ants
    import dipy.reconst.dti as dti
    from dipy.core.gradients import gradient_table
    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    model_params = model_params or {}
    img = imagein.numpy()
    mask = maskin.numpy().astype(bool)
    if img.ndim != 4:
        raise ValueError("imagein must be a 4D image (X, Y, Z, N).")
    X, Y, Z, N = img.shape

    # Boolean selection below needs ndarrays; accept lists / column vectors.
    bvals = np.asarray(bvals).ravel()
    bvecs_5d = np.asarray(bvecs_5d)
    if mask.shape != (X, Y, Z):
        raise ValueError(f"maskin shape {mask.shape} does not match image grid {(X, Y, Z)}.")
    if bvals.shape[0] != N:
        raise ValueError(f"Expected {N} b-values, got {bvals.shape[0]}.")
    if bvecs_5d.shape != (X, Y, Z, N, 3):
        raise ValueError(
            f"bvecs_5d shape {bvecs_5d.shape} does not match expected {(X, Y, Z, N, 3)}."
        )

    if bvals_to_use is not None:
        sel = np.isin(bvals, bvals_to_use)
        img = img[..., sel]
        bvals = bvals[sel]
        bvecs_5d = bvecs_5d[..., sel, :]

    FA = np.zeros((X, Y, Z), dtype=np.float32)
    MD = np.zeros((X, Y, Z), dtype=np.float32)
    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)

    def fit_voxel(ix, iy, iz):
        # Coordinates come from argwhere(mask), so no mask check is needed here.
        sig = img[ix, iy, iz, :]
        if np.all(sig == 0):
            return
        try:
            # The gradient table is inside the try: a voxel with degenerate
            # b-vectors is skipped instead of aborting the whole volume.
            # bvecs is passed by keyword (recent dipy versions warn on positional use).
            gtab = gradient_table(bvals, bvecs=bvecs_5d[ix, iy, iz, :, :])
            fit = dti.TensorModel(gtab, **model_params).fit(sig)
            evals = fit.evals
            FA[ix, iy, iz] = fractional_anisotropy(evals)
            MD[ix, iy, iz] = mean_diffusivity(evals)
            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], fit.evecs)
        except Exception as e:  # deliberate best-effort per voxel
            if verbose:
                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")

    coords = np.argwhere(mask)
    if verbose:
        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")

    if num_threads > 1:
        # Each voxel writes only its own output entries, so threads don't collide.
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords),
                      total=len(coords), disable=not verbose))
    else:
        for c in tqdm(coords, disable=not verbose):
            fit_voxel(*c)

    ref = ants.slice_image(imagein, axis=3, idx=0)
    return (
        ants.copy_image_info(ref, ants.from_numpy(FA)),
        ants.copy_image_info(ref, ants.from_numpy(MD)),
        ants.merge_channels([ants.copy_image_info(ref, ants.from_numpy(RGB[..., i])) for i in range(3)])
    )
|
4145
|
+
|
4146
|
+
|
4147
|
+
def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
    """
    Generate voxel-wise b-vectors from a global bvec and voxel-wise rotation field.

    Every global b-vector is multiplied by each voxel's 3x3 matrix
    (``R @ b``, or ``R.T @ b`` when ``transpose`` is True). This is computed in
    a single vectorised ``einsum`` rather than a Python loop over every voxel
    and direction.

    Parameters
    ----------
    global_bvecs : array-like of shape (N, 3)
        Global diffusion gradient directions.
    voxel_rotations : array-like of shape (X, Y, Z, 3, 3)
        3x3 rotation matrix for each voxel (e.g. derived from the Jacobian of
        a deformation field).
    transpose : bool, optional
        If True, transpose the rotation matrices before applying them to the
        b-vectors.

    Returns
    -------
    bvecs_5d : ndarray of shape (X, Y, Z, N, 3), dtype float32
        Voxel-specific b-vectors (not renormalised).
    """
    global_bvecs = np.asarray(global_bvecs)
    voxel_rotations = np.asarray(voxel_rotations)
    # 'ij' contracts R's columns with b (R @ b); 'ji' contracts its rows (R.T @ b).
    subscripts = 'xyzji,nj->xyzni' if transpose else 'xyzij,nj->xyzni'
    return np.einsum(subscripts, voxel_rotations, global_bvecs).astype(np.float32)
|
4180
|
+
|
3935
4181
|
def dipy_dti_recon(
|
3936
4182
|
image,
|
3937
4183
|
bvalsfn,
|
@@ -24,7 +24,9 @@ docs/deepnbm.jpg
|
|
24
24
|
docs/deformation_gradient_reo.py
|
25
25
|
docs/describe_mm_data.R
|
26
26
|
docs/dipy_dti_recon.py
|
27
|
+
docs/dti_distortion_correction_voxelwise_varying_bvectors_example.py
|
27
28
|
docs/dti_recon.py
|
29
|
+
docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py
|
28
30
|
docs/dti_reg.py
|
29
31
|
docs/dwi_rebasing.py
|
30
32
|
docs/dwi_run.py
|
@@ -62,4 +64,5 @@ docs/ukbb_to_nrg_processing2.py
|
|
62
64
|
docs/visualize_tractogram.py
|
63
65
|
tests/test_loop.py
|
64
66
|
tests/test_nrg_validation.py
|
65
|
-
tests/test_reference_run.py
|
67
|
+
tests/test_reference_run.py
|
68
|
+
tests/voxelwise_bvec_dti_recon_test.py
|
@@ -0,0 +1,353 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import ants
|
3
|
+
import os
|
4
|
+
from dipy.io.gradients import read_bvals_bvecs
|
5
|
+
from scipy.stats import pearsonr
|
6
|
+
import antspymm
|
7
|
+
import numpy as np
|
8
|
+
from scipy.stats import pearsonr
|
9
|
+
nt = 8  # thread count passed as num_threads to the efficient_dwi_fit_voxelwise calls below
|
10
|
+
|
11
|
+
##################################################################
|
12
|
+
# for easier to access data with a full mm_csv example, see:
|
13
|
+
# github.com:stnava/ANTPD_antspymm
|
14
|
+
##################################################################
|
15
|
+
from os.path import exists
|
16
|
+
import os
|
17
|
+
import signal
|
18
|
+
import urllib.request
|
19
|
+
import zipfile
|
20
|
+
import tempfile
|
21
|
+
from pathlib import Path
|
22
|
+
from tqdm import tqdm
|
23
|
+
import antspynet
|
24
|
+
|
25
|
+
# Paths, relative to the data root, whose presence marks a usable test dataset.
REQUIRED_FILES = list((
    "PPMI/101018/20210412/T1w/1496225/PPMI-101018-20210412-T1w-1496225.nii.gz",
    "PPMI/101018/20210412/DTI_LR/1496234/PPMI-101018-20210412-DTI_LR-1496234.nii.gz",
))
|
29
|
+
|
30
|
+
def broadcast_bvecs_voxelwise(rotated_bvecs, shape):
    """
    Tile one (N, 3) b-vector table across a spatial grid.

    Parameters
    ----------
    rotated_bvecs : array-like, shape (N, 3)
        b-vectors shared by every voxel.
    shape : sequence of int
        Spatial grid shape, e.g. ``image.shape``; a tuple or a list.

    Returns
    -------
    np.ndarray, shape tuple(shape) + (N, 3)
        A writable copy (not a read-only broadcast view).
    """
    rotated_bvecs = np.asarray(rotated_bvecs)
    # tuple() lets list shapes work; list + tuple would raise TypeError.
    return np.broadcast_to(rotated_bvecs, tuple(shape) + rotated_bvecs.shape).copy()
|
32
|
+
|
33
|
+
|
34
|
+
def _validate_required_files(base_dir, required_files):
|
35
|
+
for rel_path in required_files:
|
36
|
+
full_path = os.path.join(base_dir, rel_path)
|
37
|
+
if not os.path.isfile(full_path):
|
38
|
+
print(f"❌ Missing required file: {rel_path}")
|
39
|
+
return False
|
40
|
+
return True
|
41
|
+
|
42
|
+
def _download_with_progress(url, destination):
    """
    Stream ``url`` into the file ``destination`` in 8 KiB chunks.

    A tqdm byte-progress bar is shown, sized from the response's
    Content-Length header (0 when the server omits it).
    """
    chunk_size = 8192
    with urllib.request.urlopen(url) as response, open(destination, 'wb') as sink:
        expected_bytes = int(response.getheader('Content-Length', 0))
        with tqdm(total=expected_bytes, unit='B', unit_scale=True, desc="Downloading", ncols=80) as bar:
            # iter() with a b"" sentinel stops at end of stream.
            for block in iter(lambda: response.read(chunk_size), b""):
                sink.write(block)
                bar.update(len(block))
|
52
|
+
|
53
|
+
def _fetch_test_data(download_root, required_files):
    """Ensure the figshare test archive is extracted into ``{download_root}/nrgdata_test`` and return that path."""
    base_dir = os.path.expanduser(download_root)
    target_dir = os.path.join(base_dir, "nrgdata_test")
    if not os.path.isdir(target_dir) or not _validate_required_files(target_dir, required_files):
        print(f"📥 Will download data to: {target_dir}")
        url = "https://figshare.com/ndownloader/articles/29391236/versions/1"
        os.makedirs(base_dir, exist_ok=True)
        zip_path = os.path.join(tempfile.gettempdir(), "antspymm_testdata.zip")

        try:
            _download_with_progress(url, zip_path)
            print("📦 Extracting...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(base_dir)
            print(f"✅ Extracted to {target_dir}")
        except Exception as e:
            raise RuntimeError(f"❌ Download or extraction failed: {e}") from e
        finally:
            # The archive is only needed for extraction; don't leave it in the
            # temp dir. Best-effort: a failed cleanup must not mask the result.
            try:
                if os.path.exists(zip_path):
                    os.remove(zip_path)
            except OSError:
                pass

    if not _validate_required_files(target_dir, required_files):
        raise RuntimeError(f"❌ Downloaded data is missing required files in {target_dir}")
    return target_dir


def _prompt_for_data_dir(max_tries, timeout, required_files):
    """Ask the user for a data directory, with a per-attempt SIGALRM timeout on POSIX."""
    use_alarm = os.name == 'posix'

    def timeout_handler(signum, frame):
        raise TimeoutError("⏳ No input received in time.")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler) if use_alarm else None

    print("🔍 Could not find valid data. You may enter a directory manually.")
    print("🔗 Dataset info: https://figshare.com/articles/dataset/ANTsPyMM_testing_data/29391236")

    try:
        for attempt in range(1, max_tries + 1):
            try:
                if use_alarm:
                    signal.alarm(timeout)
                try:
                    user_input = input(f"⏱️ Attempt {attempt}/{max_tries} — Enter data directory (or 'q' to quit): ").strip()
                finally:
                    # Always cancel the pending alarm, whatever input() did
                    # (e.g. EOFError); otherwise it fires later in unrelated code.
                    if use_alarm:
                        signal.alarm(0)
            except TimeoutError as e:
                raise RuntimeError(str(e)) from e
            except KeyboardInterrupt:
                raise RuntimeError("User interrupted execution. Exiting.")
            except EOFError:
                # stdin is closed (non-interactive run): nothing more can be read.
                break

            if user_input.lower() == 'q':
                break

            path = os.path.expanduser(user_input)
            if os.path.isdir(path) and _validate_required_files(path, required_files):
                print(f"✅ Using user-provided directory: {path}")
                return path
            print("❌ Invalid or incomplete directory.")
    finally:
        # Restore the caller's SIGALRM handler (None means it was not set from Python).
        if use_alarm and previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

    raise RuntimeError("❗ No valid data directory found and download not permitted.")


def find_data_dir(candidate_paths=None, max_tries=5, timeout=22, allow_download=None, required_files=REQUIRED_FILES):
    """
    Attempts to locate or download the ANTsPyMM testing dataset.

    The search runs in this order:

    1. Each directory in ``candidate_paths``.
    2. If ``allow_download`` is a string, download and extract the figshare
       archive into ``{allow_download}/nrgdata_test``.
    3. Otherwise, prompt the user for a directory up to ``max_tries`` times.

    Parameters
    ----------
    candidate_paths : list of str or None
        Directories to search for the data. If None, uses sensible defaults.
    max_tries : int
        Number of chances to enter a valid path manually.
    timeout : int
        Seconds to wait for user input before timing out (POSIX only).
    allow_download : None | str
        If not None, will download to {allow_download}/nrgdata_test if needed.
    required_files : list of str
        Relative paths that must exist inside the data directory.

    Returns
    -------
    str
        Path to a valid data directory (as given/expanded; it may or may not
        end with a path separator).

    Raises
    ------
    RuntimeError
        Raised in any of these cases:
        - the download or extraction fails, or the extracted data is incomplete;
        - the prompt times out or is interrupted;
        - stdin is closed;
        - no valid directory is given.
    """
    if candidate_paths is None:
        candidate_paths = [
            "~/Downloads/temp/shortrun/nrgdata_test",
            "~/Downloads/ANTsPyMM_testing_data/nrgdata_test",
            "~/data/ppmi/nrgdata_test",
            "/mnt/data/nrgdata_test"
        ]

    # First, search known paths.
    for path in candidate_paths:
        full_path = os.path.expanduser(path)
        if os.path.isdir(full_path) and _validate_required_files(full_path, required_files):
            print(f"✅ Found valid data directory: {full_path}")
            return full_path

    # Automatic download, when permitted.
    if isinstance(allow_download, str):
        return _fetch_test_data(allow_download, required_files)

    # Fall back to asking the user.
    return _prompt_for_data_dir(max_tries, timeout, required_files)
|
148
|
+
|
149
|
+
candidate_rdirs = [
    "~/Downloads/nrgdata_test/",
    "~/Downloads/temp/nrgdata_test/",
    "~/nrgdata_test/",
    "~/data/ppmi/nrgdata_test/",
    "/mnt/data/ppmi_testing/nrgdata_test/"]


rdir = find_data_dir( candidate_rdirs, allow_download="~/Downloads" )
print(f"Using data directory: {rdir}")

# Use the same thread count as the num_threads=nt fits below instead of
# repeating the literal 8 a second time.
nthreads = str(nt)
# NOTE(review): numpy, ants and antspynet are already imported at the top of
# this script before these variables are set; libraries that read their
# thread settings at import time may ignore them. Set them before those
# imports if the limit matters.
os.environ["TF_NUM_INTEROP_THREADS"] = nthreads
os.environ["TF_NUM_INTRAOP_THREADS"] = nthreads
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = nthreads
os.environ["OPENBLAS_NUM_THREADS"] = nthreads
os.environ["MKL_NUM_THREADS"] = nthreads
import numpy as np
import glob as glob
import antspymm
import ants
import random
import re
|
172
|
+
|
173
|
+
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a b-vector text file and rotate every direction by ``rotmat``.

    The file may be stored as 3 rows (x, y, z) or as N rows of 3 columns.
    ``ndmin=2`` keeps a single-direction file 2-D, so it still yields one
    (1, 3) row instead of a flat (3,) vector. A file with exactly 3 rows is
    treated as the 3 x N layout.

    Parameters
    ----------
    bvec_file : str or path-like
        Path to the b-vector text file.
    rotmat : array-like, shape (3, 3)
        Rotation applied as ``rotmat @ b``.

    Returns
    -------
    np.ndarray, shape (N, 3)
        Rotated b-vectors, one per row.
    """
    bvecs = np.loadtxt(bvec_file, ndmin=2)
    if bvecs.shape[0] != 3:
        bvecs = bvecs.T  # N x 3 file -> 3 x N
    return (np.asarray(rotmat) @ bvecs).T
|
179
|
+
|
180
|
+
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean correlation between two RGB images.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multichannel (RGB) image; split with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multichannel (RGB) image with the same shape as ``img1``.
    mask : ants.ANTsImage
        Voxels where ``mask > 0`` are compared.

    Returns
    -------
    float
        Mean Pearson correlation over the first three (R, G, B) channels,
        computed inside the mask. A channel that is constant within the mask
        in either image contributes 0.0.

    Raises
    ------
    ValueError
        If the images differ in shape or have fewer than three channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if min(len(img1c), len(img2c)) < 3:
        raise ValueError("Input images must have at least three (R, G, B) channels.")
    correlations = []
    for c in range(3):  # R, G, B
        x = extract_masked_values(img1c[c], mask)
        y = extract_masked_values(img2c[c], mask)
        if np.std(x) == 0 or np.std(y) == 0:
            corr = 0.0  # Pearson r is undefined for a flat channel
        else:
            corr, _ = pearsonr(x, y)
        correlations.append(corr)
    return np.mean(correlations)
|
210
|
+
|
211
|
+
import numpy as np
|
212
|
+
import ants
|
213
|
+
|
214
|
+
def mean_rgb_mae(img1, img2, mask):
    """
    Compute the mean absolute error (MAE) between two RGB images.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multichannel (RGB) image; split with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multichannel (RGB) image with the same shape as ``img1``.
    mask : ants.ANTsImage
        Binary mask defining valid voxels (``mask > 0``) for the error calculation.

    Returns
    -------
    float
        MAE of each of the first three (R, G, B) channels inside the mask,
        averaged over the channels.

    Raises
    ------
    ValueError
        If the images differ in shape or have fewer than three channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")

    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if min(len(img1c), len(img2c)) < 3:
        raise ValueError("Input images must have at least three (R, G, B) channels.")

    mae_values = []
    for c in range(3):  # R, G, B
        # Compare in float64 so integer/unsigned pixel types cannot wrap on subtraction.
        x = extract_masked_values(img1c[c], mask).astype(np.float64)
        y = extract_masked_values(img2c[c], mask).astype(np.float64)
        mae_values.append(np.mean(np.abs(x - y)))

    return np.mean(mae_values)
|
246
|
+
|
247
|
+
def extract_masked_values(image, mask):
    """Return the voxels of ``image`` where ``mask`` is positive, as a 1-D array."""
    keep = mask.numpy() > 0
    values = image.numpy()
    return values[keep]
|
249
|
+
|
250
|
+
import numpy as np
|
251
|
+
|
252
|
+
def verify_unit_bvecs(bvecs, tol=1e-5):
    """
    Report which b-vectors have (approximately) unit length.

    Parameters
    ----------
    bvecs : array-like, shape (N, 3)
        One diffusion direction per row.
    tol : float
        A vector counts as unit length when ``|norm - 1| < tol``.

    Returns
    -------
    is_unit : np.ndarray of bool, shape (N,)
        Per-vector unit-norm flag.
    norms : np.ndarray, shape (N,)
        Euclidean norm of each vector.
    """
    lengths = np.linalg.norm(np.asarray(bvecs), axis=1)
    deviation = np.abs(lengths - 1)
    return deviation < tol, lengths
|
274
|
+
|
275
|
+
# find_data_dir may return a directory without a trailing separator (the
# download target, or user input), so join instead of concatenating; plain
# "rdir + 'PPMI/'" would then produce ".../nrgdata_testPPMI/".
mydir = os.path.join(rdir, "PPMI", "")
outdir = re.sub( 'nrgdata_test', 'antspymmoutput', rdir )
import glob as glob

t1fn = glob.glob(mydir + "101018/20210412/T1w/1496225/*.nii.gz")
if len(t1fn) > 0:
    t1fn = t1fn[0]
    print("Begin " + t1fn)
dtfn = glob.glob(mydir + "101018/20210412/DTI*/*/*.nii.gz")
dtfn.sort()

import re

# Scratch location for the composed warp and the example outputs
# (portable replacement for a hard-coded /tmp).
scratch_dir = tempfile.gettempdir()

# def test_efficient_dwi_fit_voxelwise_distortion_correction():
if len(dtfn) > 0:
    img_LR_in = ants.image_read(dtfn[0])
    img_LR_in_avg = ants.get_average_of_timeseries( img_LR_in )
    mask = img_LR_in_avg.get_mask()
    # Swap only the trailing extension (the glob guarantees ".nii.gz");
    # re.sub('nii.gz', ...) treated "." as a wildcard and rewrote every match.
    dwi_stem = dtfn[0][:-len(".nii.gz")]
    bvalfn = dwi_stem + ".bval"
    bvecfn = dwi_stem + ".bvec"
    if not exists(bvalfn) or not exists(bvecfn):
        raise RuntimeError(f"Required bval/bvec files not found: {bvalfn}, {bvecfn}")
    print(f"📁 Loading subject LR data from {bvalfn} ")
    bvals, bvecs = read_bvals_bvecs(bvalfn, bvecfn)
    bvecs = np.asarray(bvecs)
    shape = img_LR_in.shape[:3]

    print("📁 Loading subject T1 data...")
    t1w = ants.image_read(t1fn)
    t1w = ants.resample_image(t1w, [2, 2, 2], use_voxels=False)
    bxt = antspynet.brain_extraction(t1w, modality='t1', verbose=False).threshold_image(0.5, 1.5)

    # The globals() checks let an interactive session re-run the script
    # without repeating the expensive registration / fits.
    if "mytx" not in globals():
        dwianat = ants.slice_image( img_LR_in, idx=0, axis=3)
        mytx = ants.registration( t1w, dwianat, 'SyNCC', syn_metric='CC', syn_sampling=2, total_sigma=0.5 )
    mytx2 = ants.apply_transforms(t1w, img_LR_in_avg, mytx['fwdtransforms'],
        interpolator='linear', imagetype=0, compose=os.path.join(scratch_dir, 'comptxDT2T1') )
    print( mytx2 )

    print("🔄 now with distortion correction...")
    mydef = ants.image_read( mytx2 )
    mywarp = ants.transform_from_displacement_field( mydef )
    img_w = antspymm.timeseries_transform(mywarp, img_LR_in, reference=t1w)
    print("🧠 Running warped fit...")

    if "FA_w" not in globals():
        mydefgrad = antspymm.deformation_gradient_optimized( mydef,
            to_rotation=False, to_inverse_rotation=True )
        bvecsdc = antspymm.distortion_correct_bvecs( bvecs, mydefgrad, t1w.direction, img_LR_in_avg.direction )
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=bxt,
            bvals=bvals,
            bvecs_5d=bvecsdc,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    if "FA_w2" not in globals():
        FA_w2, MD_w2, RGB_w2 = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=bxt,
            bvals=bvals,
            bvecs_5d=broadcast_bvecs_voxelwise(bvecs, t1w.shape),
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    print("📊 Comparing results...")
    maske = ants.iMath(bxt, 'ME', 3)
    fa_corr = mean_rgb_correlation( RGB_w, RGB_w2, maske )
    print(f"✅ Direction-weighted FA correlation (original vs distortion corrected): {fa_corr:.4f}")

    ants.image_write( t1w, os.path.join(scratch_dir, 't1w.nii.gz') )
    ants.image_write( RGB_w, os.path.join(scratch_dir, 'rgbw.nii.gz') )
|
@@ -0,0 +1,282 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import ants
|
3
|
+
import os
|
4
|
+
from dipy.io.gradients import read_bvals_bvecs
|
5
|
+
from scipy.stats import pearsonr
|
6
|
+
import antspymm
|
7
|
+
import matplotlib.pyplot as plt
|
8
|
+
nt = 2  # thread count for this example (presumably passed as num_threads to the fits)
# Amount of rotation around the x, y and z axes (presumably in degrees, per the names).
degrotx = 15
degroty = 15
degrotz = 15
|
13
|
+
|
14
|
+
|
15
|
+
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean correlation between two RGB images.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multichannel (RGB) image; split with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multichannel (RGB) image with the same shape as ``img1``.
    mask : ants.ANTsImage
        Voxels where ``mask > 0`` are compared.

    Returns
    -------
    float
        Mean Pearson correlation over the first three (R, G, B) channels,
        computed inside the mask. A channel that is constant within the mask
        in either image contributes 0.0.

    Raises
    ------
    ValueError
        If the images differ in shape or have fewer than three channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if min(len(img1c), len(img2c)) < 3:
        raise ValueError("Input images must have at least three (R, G, B) channels.")
    correlations = []
    for c in range(3):  # R, G, B
        x = extract_masked_values(img1c[c], mask)
        y = extract_masked_values(img2c[c], mask)
        if np.std(x) == 0 or np.std(y) == 0:
            corr = 0.0  # Pearson r is undefined for a flat channel
        else:
            corr, _ = pearsonr(x, y)
        correlations.append(corr)
    return np.mean(correlations)
|
45
|
+
|
46
|
+
|
47
|
+
def plot_correlation(x, y, xlabel="Image 1", ylabel="Image 2", title_prefix="Correlation", point_alpha=0.3):
    """
    Plot the correlation between two arrays and display Pearson r.

    Parameters
    ----------
    x : array-like
        First set of values. Any shape is accepted; values are flattened.
    y : array-like
        Second set of values; must have the same shape as ``x``.
    xlabel : str
        Label for the x-axis.
    ylabel : str
        Label for the y-axis.
    title_prefix : str
        Title prefix, will append r value.
    point_alpha : float
        Transparency of scatter plot points.

    Returns
    -------
    float
        Pearson correlation coefficient (r).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    if x.shape != y.shape:
        raise ValueError("Inputs must have the same shape.")

    # pearsonr expects 1-D samples; flattening lets image arrays be passed directly.
    x = x.ravel()
    y = y.ravel()

    r, _ = pearsonr(x, y)

    # Span the identity line over both axes' data so it is not truncated when
    # the y values extend beyond the x range.
    lo = min(x.min(), y.min())
    hi = max(x.max(), y.max())

    plt.figure(figsize=(6, 6))
    plt.scatter(x, y, s=1, alpha=point_alpha, color='blue')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(f"{title_prefix} (r = {r:.3f})")
    plt.plot([lo, hi], [lo, hi], 'r--', linewidth=1)  # Identity line
    plt.grid(True)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()

    return r
|
91
|
+
|
92
|
+
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Read a b-vector text file and rotate every gradient direction.

    Both the FSL layout (3 rows x N columns) and the transposed layout
    (N rows x 3 columns) are accepted; a 3 x 3 table is read as FSL layout.

    Parameters
    ----------
    bvec_file : str or path-like
        Whitespace-delimited b-vector file.
    rotmat : array-like, shape (3, 3)
        Rotation applied to each direction as ``rotmat @ v``.

    Returns
    -------
    np.ndarray, shape (N, 3)
        Rotated b-vectors, one direction per row (also for a single direction).

    Raises
    ------
    ValueError
        If neither dimension of the table has length 3.
    """
    # ndmin=2 keeps a single-direction file 2-D, so the layout check below
    # still works and the result is (1, 3) rather than (3,).
    raw = np.loadtxt(bvec_file, ndmin=2)
    bvecs = raw if raw.shape[0] == 3 else raw.T
    if bvecs.shape[0] != 3:
        raise ValueError(
            f"b-vector table must be 3 x N or N x 3; got shape {raw.shape}."
        )
    rotated_bvecs = (np.asarray(rotmat) @ bvecs).T
    return rotated_bvecs
|
98
|
+
|
99
|
+
def broadcast_bvecs_voxelwise(rotated_bvecs, shape):
    """
    Replicate one b-vector table at every voxel of a spatial grid.

    Parameters
    ----------
    rotated_bvecs : array-like, shape (N, 3)
        Gradient directions, one per row.
    shape : sequence of int
        Spatial grid shape, e.g. ``(X, Y, Z)``; any sequence (tuple, list) works.

    Returns
    -------
    np.ndarray, shape ``(*shape, N, 3)``
        A writeable copy, not the read-only view ``np.broadcast_to`` returns.
    """
    bvecs_arr = np.asarray(rotated_bvecs)
    return np.broadcast_to(bvecs_arr, tuple(shape) + bvecs_arr.shape).copy()
|
101
|
+
|
102
|
+
def extract_masked_values(image, mask):
    """Return the values of *image* at voxels where *mask* is strictly positive, as a 1-D array."""
    voxels = image.numpy()
    inside = mask.numpy() > 0
    return voxels[inside]
|
104
|
+
|
105
|
+
# Script body (originally drafted as
# test_efficient_dwi_fit_voxelwise_rotation_consistency): first check that
# efficient_dwi_fit_voxelwise is consistent under a known rigid rotation, then
# use the rotated data to exercise b-vector reorientation after registration.
# The `"<name>" not in globals()` guards let an interactive session re-run the
# script without repeating the expensive fits / registration.
if True:
    import tempfile

    print("simple test for rotation consistency of efficient_dwi_fit_voxelwise")
    ex_path_mm = os.path.expanduser("~/.antspymm/")

    #### Load in data ####
    print("📁 Loading subject data...")
    lrid = os.path.join(ex_path_mm, "I1499279_Anon_20210819142214_5")
    img_LR_in = ants.image_read(lrid + '.nii.gz')
    img_LR_in_avg = ants.get_average_of_timeseries(img_LR_in)
    mask = img_LR_in_avg.get_mask()

    bvals, bvecs = read_bvals_bvecs(lrid + '.bval', lrid + '.bvec')
    bvecs = np.asarray(bvecs)
    shape = img_LR_in.shape[:3]

    print("🧠 Running baseline (unrotated) fit...")
    bvecs_5d_orig = broadcast_bvecs_voxelwise(bvecs, shape)
    if "FA_orig" not in globals():
        FA_orig, MD_orig, RGB_orig = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_LR_in,
            maskin=mask,
            bvals=bvals,
            bvecs_5d=bvecs_5d_orig,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    print("🔄 Applying known rotation...")
    # maxtrans = 10.0
    # rotator = ants.contrib.RandomRotate3D((-maxtrans, maxtrans), reference=img_LR_in_avg)
    rotator = ants.contrib.Rotate3D(rotation=(degrotx, degroty, degrotz), reference=img_LR_in_avg)
    rotation = rotator.transform()
    # The inverse is needed several times below (b-vector rotation, resampling
    # back to the original space, seeding the registration); compute it once.
    rotationinv = rotation.invert()
    img_rotated_ref = ants.apply_ants_transform_to_image(rotation, img_LR_in_avg, reference=img_LR_in_avg)
    img_rotated = antspymm.timeseries_transform(rotation, img_LR_in, reference=img_LR_in_avg)
    mask_rotated = ants.apply_ants_transform_to_image(
        rotation, mask, reference=img_LR_in_avg, interpolation='nearestNeighbor'
    )
    # note: we apply the inverse rotation to the bvecs
    # i.e. if we register A to B and get R then apply R_inv to bvecs
    rotmat = ants.get_ants_transform_parameters(rotationinv).reshape((4, 3))[:3, :3]
    bvecs_rotated = read_bvecs_rotated(lrid + '.bvec', rotmat)
    bvecs_5d_rot = broadcast_bvecs_voxelwise(bvecs_rotated, shape)

    print("🧠 Running rotated fit...")
    if "FA_rot" not in globals():
        FA_rot, MD_rot, RGB_rot = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_rotated,
            maskin=mask_rotated,
            bvals=bvals,
            bvecs_5d=bvecs_5d_rot,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    # Resample the rotated-space maps back onto the original grid.
    FA_rot_back = ants.apply_ants_transform_to_image(rotationinv, FA_rot, reference=img_LR_in_avg)
    MD_rot_back = ants.apply_ants_transform_to_image(rotationinv, MD_rot, reference=img_LR_in_avg)

    print("📊 Comparing results...")
    maske = ants.iMath(mask, 'ME', 2)
    # smoothing simulates double interpolation
    FA_origs = ants.smooth_image(FA_orig, 1.25)
    MD_origs = ants.smooth_image(MD_orig, 1.25)
    fa1 = extract_masked_values(FA_origs, maske)
    fa2 = extract_masked_values(FA_rot_back, maske)
    md1 = extract_masked_values(MD_origs, maske)
    md2 = extract_masked_values(MD_rot_back, maske)

    fa_corr, _ = pearsonr(fa1, fa2)
    md_corr, _ = pearsonr(md1, md2)

    # plot_correlation( fa1, fa2 )

    print(f"✅ FA correlation (original vs rotated): {fa_corr:.4f}")
    print(f"✅ MD correlation (original vs rotated): {md_corr:.4f}")

    assert fa_corr > 0.80, "FA correlation too low"
    assert md_corr > 0.80, "MD correlation too low"

    print("🎉 Test passed: model is rotation-consistent with voxelwise bvecs.")

    print("This shows that the simulation is effective.")
    print("Now use the simulated data to test the distortion correction...")
    # ants.image_write( RGB_orig, '/tmp/temp0rgb.nii.gz' )
    # ants.image_write( RGB_rot, '/tmp/temp1rgb.nii.gz' )

    # now we have to map the img_rotated and its bvec_rotated partner
    # as we would with a generic distortion correction framework.
    with tempfile.TemporaryDirectory() as tempdir:
        rotmatfile = os.path.join(tempdir, "rotation.mat")
        compositefile = os.path.join(tempdir, "composite")

        # first --- write the (inverse) rotation to a file to seed the registration
        ants.write_transform(rotationinv, rotmatfile)

        # Run registration if not already done
        if "reg" not in globals():
            reg = ants.registration(FA_orig, FA_rot, 'SyN', initial_transform=rotmatfile)

        # Compose the forward transforms into one displacement field file
        comptx = ants.apply_transforms(
            img_LR_in_avg,
            img_rotated_ref,
            reg['fwdtransforms'],
            interpolator='linear',
            compose=compositefile,
            verbose=True,
        )
        mydef = ants.image_read(comptx)
        mydefgrad = antspymm.deformation_gradient_optimized(
            mydef, to_rotation=False, to_inverse_rotation=True
        )
        # bvecsRLw = antspymm.generate_voxelwise_bvecs( bvecs_rotated, mydefgrad, transpose=False )
        img_rotated_avg = ants.get_average_of_timeseries(img_rotated)
        bvecsRLw = antspymm.distortion_correct_bvecs(
            bvecs_rotated, mydefgrad, img_LR_in_avg.direction, img_rotated_avg.direction
        )
        mywarp = ants.transform_from_displacement_field(mydef)
        img_w = antspymm.timeseries_transform(mywarp, img_rotated, reference=img_rotated_avg)

        # These depend only on img_w / RGB_orig, not on which b-vectors are
        # being compared, so compute (and save) them once instead of per fit.
        img_w_avg = ants.get_average_of_timeseries(img_w)
        compmask = ants.iMath(ants.get_mask(img_w_avg), 'ME', 2)
        RGB_origs = ants.smooth_image(RGB_orig, 1.0)
        # Optional: write RGB images to temporary files for inspection
        rgb_orig_path = os.path.join(tempdir, "RGB_orig.nii.gz")
        ants.image_write(RGB_origs, rgb_orig_path)
        print(f"📝 Saved RGB_orig to: {rgb_orig_path}")

        correlations = []
        labels = ["NoBvecReo", "BvecReo"]
        for label, bv in zip(labels, [bvecs_5d_rot, bvecsRLw]):
            FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
                imagein=img_w,
                maskin=ants.get_mask(img_w_avg),
                bvals=bvals,
                bvecs_5d=bv,
                model_params={},
                bvals_to_use=None,
                num_threads=nt,
                verbose=False,
            )
            rgb_warped_path = os.path.join(tempdir, f"RGB_warped_{label}.nii.gz")
            ants.image_write(RGB_w, rgb_warped_path)
            print(f"📝 Saved RGB_warped ({label}) to: {rgb_warped_path}")

            corr = mean_rgb_correlation(RGB_origs, RGB_w, compmask)
            print(f"{label:10s} correlation: {corr:.4f}")
            correlations.append(corr)

        # Compare the two correlations (exactly one per label above)
        print("\nComparison Result:")
        if correlations[1] > correlations[0]:
            print("✅ BvecReo gives higher correlation than NoBvecReo.")
        else:
            print("❌ NoBvecReo gives equal or higher correlation than BvecReo.")
|
@@ -0,0 +1,118 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from scipy.spatial.transform import Rotation as R
|
3
|
+
import antspymm
|
4
|
+
|
5
|
+
def test_generate_voxelwise_bvecs():
    """Check that generate_voxelwise_bvecs applies a spatially constant rotation at every voxel."""
    # Define synthetic bvecs (N = 3 directions)
    bvecs = np.array([
        [1, 0, 0],  # x-axis
        [0, 1, 0],  # y-axis
        [0, 0, 1],  # z-axis
    ])

    # Define image shape (small synthetic volume)
    X, Y, Z = 2, 2, 2
    N = bvecs.shape[0]

    # Define a known rotation: 33 degrees around the z-axis
    rot = R.from_euler('z', 33, degrees=True).as_matrix()  # shape (3, 3)

    # Voxelwise rotation field (X, Y, Z, 3, 3) with the same rot everywhere;
    # np.tile with a 5-element reps tuple already produces this shape.
    voxel_rotations = np.tile(rot, (X, Y, Z, 1, 1))

    # Expected rotated bvecs: each direction (row) rotated by rot
    expected = (rot @ bvecs.T).T  # (N, 3)

    # Call the function under test
    bvecs_5d = antspymm.generate_voxelwise_bvecs(bvecs, voxel_rotations)

    assert np.shape(bvecs_5d) == (X, Y, Z, N, 3), \
        f"Unexpected output shape {np.shape(bvecs_5d)}, expected {(X, Y, Z, N, 3)}"

    # Check that all voxel outputs match expected result
    for i, j, k in np.ndindex(X, Y, Z):
        actual = bvecs_5d[i, j, k]
        assert np.allclose(actual, expected, atol=1e-6), \
            f"Mismatch at voxel {(i, j, k)}: {actual} vs {expected}"

    print("✅ test_generate_voxelwise_bvecs passed!")
|
39
|
+
|
40
|
+
# Run the test when executed as a script. Under pytest this file is collected
# (it matches *_test.py), so running it at import time would execute it twice.
if __name__ == "__main__":
    print("Running test for generate_voxelwise_bvecs...")
    test_generate_voxelwise_bvecs()
|
43
|
+
|
44
|
+
|
45
|
+
import numpy as np
|
46
|
+
import ants
|
47
|
+
from scipy.spatial.transform import Rotation as R
|
48
|
+
|
49
|
+
def generate_dummy_dwi_data(shape, n_volumes, seed=42):
    """
    Generate synthetic DWI-like 4D data and a brain mask.

    Parameters
    ----------
    shape : tuple of int
        Spatial shape ``(X, Y, Z)``.
    n_volumes : int
        Number of diffusion volumes (last axis of the returned data).
    seed : int, optional
        Seed for a private random generator (default 42). The global NumPy
        random state is left untouched.

    Returns
    -------
    dwi : np.ndarray
        float32 array of shape ``(*shape, n_volumes)`` with values in [0, 1).
    mask : np.ndarray
        uint8 array of ones with shape ``shape``.
    """
    # A private RandomState yields exactly what np.random.seed(seed) followed by
    # np.random.rand would, without mutating global state other tests rely on.
    rng = np.random.RandomState(seed)
    dwi = rng.rand(*shape, n_volumes).astype(np.float32)
    mask = np.ones(shape, dtype=np.uint8)
    return dwi, mask
|
57
|
+
|
58
|
+
def create_voxelwise_bvecs(shape, bvecs, rotation_matrix=None):
    """
    Create voxelwise (5D) bvecs array (X, Y, Z, N, 3), optionally rotated.

    Parameters
    ----------
    shape : sequence of int
        Spatial shape ``(X, Y, Z)``; tuples and lists are accepted.
    bvecs : array-like, shape (N, 3)
        Gradient directions, one per row.
    rotation_matrix : array-like, shape (3, 3), optional
        If given, every direction ``v`` is replaced by ``rotation_matrix @ v``.

    Returns
    -------
    np.ndarray, shape ``(*shape, N, 3)``
        Writeable copy with the same (possibly rotated) table at every voxel.
    """
    bvecs = np.asarray(bvecs)
    if rotation_matrix is not None:
        bvecs = (np.asarray(rotation_matrix) @ bvecs.T).T

    # broadcast_to returns a read-only view; copy so callers get a writeable array.
    return np.broadcast_to(bvecs, tuple(shape) + bvecs.shape).copy()
|
70
|
+
|
71
|
+
def test_efficient_dwi_fit_voxelwise():
    """Smoke-test efficient_dwi_fit_voxelwise on a tiny random volume with rotated voxelwise bvecs."""
    grid = (3, 3, 3)
    volume_count = 6

    # Random DWI-like data plus an all-ones mask, wrapped as ANTs images
    dwi_array, mask_array = generate_dummy_dwi_data(grid, volume_count)
    dwi_image = ants.from_numpy(dwi_array)
    mask_image = ants.from_numpy(mask_array)

    # One b=0 volume followed by five b=1000 directions
    bvals = np.array([0] + [1000] * 5)
    diag = 1 / np.sqrt(2)
    bvecs = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [diag, diag, 0],
        [diag, 0, diag],
    ])

    # Rotate every direction 45 degrees about z, then replicate per voxel
    z_rotation = R.from_euler('z', 45, degrees=True).as_matrix()
    voxel_bvecs = create_voxelwise_bvecs(grid, bvecs, z_rotation)

    FA_img, MD_img, RGB_img = antspymm.efficient_dwi_fit_voxelwise(
        imagein=dwi_image,
        maskin=mask_image,
        bvals=bvals,
        bvecs_5d=voxel_bvecs,
        model_params={},
        bvals_to_use=None,
        num_threads=1,
        verbose=False,
    )

    assert isinstance(FA_img, ants.ANTsImage), "FA_img should be an ANTsImage"
    assert FA_img.shape == grid, f"FA image has shape {FA_img.shape}, expected {grid}"
    fa_values = FA_img.numpy()
    in_unit_interval = (fa_values >= 0) & (fa_values <= 1)
    assert np.all(in_unit_interval), "FA values should be in [0, 1]"
    assert MD_img.shape == grid, "MD image has incorrect shape"

    print("✅ test_efficient_dwi_fit_voxelwise passed!")
|
116
|
+
|
117
|
+
# Run the test when executed as a script. Under pytest this file is collected
# (it matches *_test.py), so running it at import time would execute it twice.
if __name__ == "__main__":
    test_efficient_dwi_fit_voxelwise()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|