antspymm 1.5.5__tar.gz → 1.5.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {antspymm-1.5.5 → antspymm-1.5.7}/PKG-INFO +1 -1
  2. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm/__init__.py +6 -0
  3. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm/mm.py +246 -0
  4. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/PKG-INFO +1 -1
  5. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/SOURCES.txt +4 -1
  6. antspymm-1.5.7/docs/dti_distortion_correction_voxelwise_varying_bvectors_example.py +353 -0
  7. antspymm-1.5.7/docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py +282 -0
  8. {antspymm-1.5.5 → antspymm-1.5.7}/pyproject.toml +1 -1
  9. antspymm-1.5.7/tests/voxelwise_bvec_dti_recon_test.py +118 -0
  10. {antspymm-1.5.5 → antspymm-1.5.7}/MANIFEST.in +0 -0
  11. {antspymm-1.5.5 → antspymm-1.5.7}/README.md +0 -0
  12. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/dependency_links.txt +0 -0
  13. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/requires.txt +0 -0
  14. {antspymm-1.5.5 → antspymm-1.5.7}/antspymm.egg-info/top_level.txt +0 -0
  15. {antspymm-1.5.5 → antspymm-1.5.7}/docs/adni_rsfmri_2_nrg_conversion.py +0 -0
  16. {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_annotated_output_tree.pages +0 -0
  17. {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_annotated_output_tree.txt +0 -0
  18. {antspymm-1.5.5 → antspymm-1.5.7}/docs/antspymm_data_dictionary.csv +0 -0
  19. {antspymm-1.5.5 → antspymm-1.5.7}/docs/aslprep_perfusion_run_localint.py +0 -0
  20. {antspymm-1.5.5 → antspymm-1.5.7}/docs/bids_2_nrg.py +0 -0
  21. {antspymm-1.5.5 → antspymm-1.5.7}/docs/bids_cohort_example.py +0 -0
  22. {antspymm-1.5.5 → antspymm-1.5.7}/docs/bind_mm_wide.R +0 -0
  23. {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.Rmd +0 -0
  24. {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.html +0 -0
  25. {antspymm-1.5.5 → antspymm-1.5.7}/docs/blind_qc.py +0 -0
  26. {antspymm-1.5.5 → antspymm-1.5.7}/docs/convert_adni_dti_to_nrg.R +0 -0
  27. {antspymm-1.5.5 → antspymm-1.5.7}/docs/deepnbm.jpg +0 -0
  28. {antspymm-1.5.5 → antspymm-1.5.7}/docs/deformation_gradient_reo.py +0 -0
  29. {antspymm-1.5.5 → antspymm-1.5.7}/docs/describe_mm_data.R +0 -0
  30. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dipy_dti_recon.py +0 -0
  31. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dti_recon.py +0 -0
  32. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dti_reg.py +0 -0
  33. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_rebasing.py +0 -0
  34. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_run.py +0 -0
  35. {antspymm-1.5.5 → antspymm-1.5.7}/docs/dwi_run_ptbp_scrub.py +0 -0
  36. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ex_rsfmri_run_minimal_ptbp.py +0 -0
  37. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ex_sr.py +0 -0
  38. {antspymm-1.5.5 → antspymm-1.5.7}/docs/example_antspymm_output.csv +0 -0
  39. {antspymm-1.5.5 → antspymm-1.5.7}/docs/example_run_from_directory.py +0 -0
  40. {antspymm-1.5.5 → antspymm-1.5.7}/docs/flair_run_localint.py +0 -0
  41. {antspymm-1.5.5 → antspymm-1.5.7}/docs/joint_dti_recon_localint.py +0 -0
  42. {antspymm-1.5.5 → antspymm-1.5.7}/docs/make_dict_table.Rmd +0 -0
  43. {antspymm-1.5.5 → antspymm-1.5.7}/docs/make_dict_table.html +0 -0
  44. {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm.py +0 -0
  45. {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_csv_ex_2.py +0 -0
  46. {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_csv_localint.py +0 -0
  47. {antspymm-1.5.5 → antspymm-1.5.7}/docs/mm_nrg.py +0 -0
  48. {antspymm-1.5.5 → antspymm-1.5.7}/docs/nrg_cohort_example.py +0 -0
  49. {antspymm-1.5.5 → antspymm-1.5.7}/docs/parallel_study_aggregation_example.py +0 -0
  50. {antspymm-1.5.5 → antspymm-1.5.7}/docs/perfusion_ptbp.py +0 -0
  51. {antspymm-1.5.5 → antspymm-1.5.7}/docs/perfusion_run_nnl.py +0 -0
  52. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step1_blind_qc.py +0 -0
  53. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step2_outlierness.py +0 -0
  54. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step3_mm_nrg_csv.py +0 -0
  55. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ppmi_step4_aggregate.py +0 -0
  56. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ptbp_nrg.py +0 -0
  57. {antspymm-1.5.5 → antspymm-1.5.7}/docs/roi_visualization.py +0 -0
  58. {antspymm-1.5.5 → antspymm-1.5.7}/docs/roi_visualization_ppmi.py +0 -0
  59. {antspymm-1.5.5 → antspymm-1.5.7}/docs/rsfmri_run_minimal_localint.py +0 -0
  60. {antspymm-1.5.5 → antspymm-1.5.7}/docs/run_local_integration_scripts.py +0 -0
  61. {antspymm-1.5.5 → antspymm-1.5.7}/docs/run_mm_example.sh +0 -0
  62. {antspymm-1.5.5 → antspymm-1.5.7}/docs/template_overlays.py +0 -0
  63. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_rsfmri.py +0 -0
  64. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_to_nrg_processing.py +0 -0
  65. {antspymm-1.5.5 → antspymm-1.5.7}/docs/ukbb_to_nrg_processing2.py +0 -0
  66. {antspymm-1.5.5 → antspymm-1.5.7}/docs/visualize_tractogram.py +0 -0
  67. {antspymm-1.5.5 → antspymm-1.5.7}/setup.cfg +0 -0
  68. {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_loop.py +0 -0
  69. {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_nrg_validation.py +0 -0
  70. {antspymm-1.5.5 → antspymm-1.5.7}/tests/test_reference_run.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: antspymm
3
- Version: 1.5.5
3
+ Version: 1.5.7
4
4
  Summary: multi-channel/time-series medical image processing with antspyx
5
5
  Author-email: "Avants, Gosselin, Tustison, Reardon" <stnava@gmail.com>
6
6
  License: Apache-2.0
@@ -65,6 +65,8 @@ from .mm import mc_denoise
65
65
  from .mm import mc_reg
66
66
  from .mm import dti_reg
67
67
  from .mm import timeseries_reg
68
+ from .mm import timeseries_transform
69
+ from .mm import copy_spatial_metadata_from_3d_to_4d
68
70
  from .mm import concat_dewarp
69
71
  from .mm import mc_resample_image_to_target
70
72
  from .mm import trim_dti_mask
@@ -136,5 +138,9 @@ from .mm import segment_timeseries_by_bvalue
136
138
  from .mm import shorten_pymm_names
137
139
  from .mm import pet3d_summary
138
140
  from .mm import deformation_gradient_optimized
141
+ from .mm import efficient_dwi_fit_voxelwise
142
+ from .mm import generate_voxelwise_bvecs
143
+ from .mm import distortion_correct_bvecs
144
+
139
145
 
140
146
 
@@ -1794,6 +1794,89 @@ def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
1794
1794
  mimg.append( temp )
1795
1795
  return ants.list_to_ndimage( img_LR, mimg )
1796
1796
 
1797
def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
    """
    Give a 4D image the spatial header of a 3D image.

    The origin, spacing and 3x3 direction block of ``spatial_img`` replace the
    spatial (first three) entries of ``timeseries_img``'s header. The fourth
    (time) entries of origin and spacing, and the last row/column of the 4x4
    direction matrix, are kept from ``timeseries_img``. Voxel data are copied
    unchanged.

    Parameters
    ----------
    spatial_img : ants.ANTsImage
        3D image whose spatial metadata should be used.
    timeseries_img : ants.ANTsImage
        4D image supplying the voxel data and the time-axis metadata.

    Returns
    -------
    ants.ANTsImage
        New 4D image carrying the combined metadata.

    Raises
    ------
    ValueError
        If ``spatial_img`` is not 3D or ``timeseries_img`` is not 4D.
    """
    if spatial_img.dimension != 3:
        raise ValueError("spatial_img must be a 3D ANTsImage.")
    if timeseries_img.dimension != 4:
        raise ValueError("timeseries_img must be a 4D ANTsImage.")

    time_origin = timeseries_img.origin[3]
    time_spacing = timeseries_img.spacing[3]

    # Start from the 4x4 matrix so the time row/column survive, then
    # overwrite the spatial 3x3 block.
    merged_direction = timeseries_img.direction.copy()
    merged_direction[:3, :3] = spatial_img.direction

    return ants.from_numpy(
        timeseries_img.numpy(),
        origin=[*spatial_img.origin, time_origin],
        spacing=[*spatial_img.spacing, time_spacing],
        direction=merged_direction,
    )
1840
+
1841
def timeseries_transform(transform, image, reference, interpolation='linear'):
    """
    Apply one spatial transform to every 3D volume of a 4D time series.

    Parameters
    ----------
    transform : ants.ANTsTransform
        Transform object passed straight to
        ``ants.apply_ants_transform_to_image`` (for example one built with
        ``ants.transform_from_displacement_field``). This is a transform
        object, not a file path.
    image : ants.ANTsImage
        4D input image with shape (X, Y, Z, T).
    reference : ants.ANTsImage
        3D image defining the output grid (shape, origin, spacing, direction).
    interpolation : str
        Interpolation method forwarded to
        ``ants.apply_ants_transform_to_image``: 'linear', 'nearestNeighbor', etc.

    Returns
    -------
    ants.ANTsImage
        4D image on ``reference``'s spatial grid. The time-axis origin and
        spacing (and the time row/column of the direction matrix) come from
        ``image``.

    Raises
    ------
    ValueError
        If ``image`` is not 4D or ``reference`` is not 3D.
    """
    if image.dimension != 4:
        raise ValueError("Input image must be 4D (X, Y, Z, T).")
    # The reference is only used for metadata after the (potentially long)
    # per-volume loop, so validate it before doing any work.
    if reference.dimension != 3:
        raise ValueError("reference must be a 3D ANTsImage.")

    n_volumes = image.shape[3]
    transformed_volumes = []
    for t in range(n_volumes):
        vol = ants.slice_image(image, axis=3, idx=t)
        transformed = ants.apply_ants_transform_to_image(
            transform=transform,
            image=vol,
            reference=reference,
            interpolation=interpolation,
        )
        transformed_volumes.append(transformed.numpy())

    # Stack along the time axis; the spatial shape now follows the reference.
    transformed_array = np.stack(transformed_volumes, axis=-1)
    out_image = ants.from_numpy(transformed_array)
    # First inherit the 4D header of the input (keeps the time axis), then
    # replace its spatial part with the reference's.
    out_image = ants.copy_image_info(image, out_image)
    out_image = copy_spatial_metadata_from_3d_to_4d(reference, out_image)
    return out_image
1797
1880
 
1798
1881
  def timeseries_reg(
1799
1882
  image,
@@ -2055,6 +2138,40 @@ def bvec_reorientation( motion_parameters, bvecs, rebase=None ):
2055
2138
  bvecs[myidx,:] = np.dot( rebase, bvecs[myidx,:] )
2056
2139
  return bvecs
2057
2140
 
2141
+
2142
def distortion_correct_bvecs(bvecs, def_grad, A_img, A_ref):
    """
    Vectorized computation of voxel-wise distortion-corrected b-vectors.

    For every voxel the combined matrix
    ``R = A_ref.T @ A_img @ def_grad[x, y, z]`` is applied to each b-vector
    (``R @ b``). The result is rescaled to unit length. Zero-length b-vectors
    (e.g. b=0 volumes) stay zero.

    Parameters
    ----------
    bvecs : array-like (N, 3)
        Global b-vectors, one row per volume.
    def_grad : array-like (X, Y, Z, 3, 3)
        Per-voxel 3x3 matrices (rotations derived from the deformation gradient).
    A_img : array-like (3, 3)
        Direction matrix of the fixed image (target undistorted space).
    A_ref : array-like (3, 3)
        Direction matrix of the moving image (being corrected).

    Returns
    -------
    bvecs_5d : ndarray (X, Y, Z, N, 3), float64
        Unit-norm (or zero) b-vectors for every voxel.

    Raises
    ------
    ValueError
        If an input does not have the shape listed above.
    """
    # Convert to float so list inputs work and the in-place normalisation
    # below cannot fail on integer arrays.
    bvecs = np.asarray(bvecs, dtype=float)
    def_grad = np.asarray(def_grad, dtype=float)
    A_img = np.asarray(A_img, dtype=float)
    A_ref = np.asarray(A_ref, dtype=float)
    if bvecs.ndim != 2 or bvecs.shape[1] != 3:
        raise ValueError("bvecs must have shape (N, 3).")
    if def_grad.ndim != 5 or def_grad.shape[3:] != (3, 3):
        raise ValueError("def_grad must have shape (X, Y, Z, 3, 3).")
    if A_img.shape != (3, 3) or A_ref.shape != (3, 3):
        raise ValueError("A_img and A_ref must have shape (3, 3).")

    X, Y, Z = def_grad.shape[:3]
    N = bvecs.shape[0]
    # Combined per-voxel matrix: R_voxel = A_ref.T @ A_img @ def_grad
    A = A_ref.T @ A_img
    R_voxel = np.einsum('ij,xyzjk->xyzik', A, def_grad)  # (X, Y, Z, 3, 3)
    # Apply R_voxel @ b for every voxel and every b-vector -> (X*Y*Z, N, 3)
    rotated = np.einsum('vij,nj->vni', R_voxel.reshape(-1, 3, 3), bvecs)
    # Normalise; clipping the norm keeps zero vectors at zero instead of
    # dividing by zero.
    norms = np.linalg.norm(rotated, axis=2, keepdims=True)
    rotated /= np.clip(norms, 1e-8, None)
    return rotated.reshape(X, Y, Z, N, 3)
2174
+
2058
2175
  def get_dti( reference_image, tensormodel, upper_triangular=True, return_image=False ):
2059
2176
  """
2060
2177
  extract DTI data from a dipy tensormodel
@@ -3932,6 +4049,135 @@ def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
3932
4049
  return full_fit, FA_img, MD_img, RGB_img
3933
4050
 
3934
4051
 
4052
def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
                                bvals_to_use=None, num_threads=1, verbose=True):
    """
    Voxel-wise DTI fitting with an individual b-vector table per voxel.

    A dipy ``TensorModel`` is built and fit separately at every non-zero voxel
    of ``maskin``, using that voxel's b-vectors from ``bvecs_5d``. The
    following voxels are left at 0 in every output:

    * voxels outside the mask;
    * voxels whose signal is all zero;
    * voxels whose tensor fit raises.

    Parameters
    ----------
    imagein : ants.ANTsImage
        4D DWI image (X, Y, Z, N).
    maskin : ants.ANTsImage
        3D mask on the same (X, Y, Z) grid; non-zero voxels are fit.
    bvals : (N,) array-like
        b-values shared by all voxels.
    bvecs_5d : (X, Y, Z, N, 3) array-like
        Voxel-specific b-vectors.
    model_params : dict, optional
        Extra keyword arguments for ``dipy.reconst.dti.TensorModel``.
    bvals_to_use : list of numbers, optional
        If given, only volumes whose b-value is in this list are used.
    num_threads : int
        Number of worker threads; values <= 1 run serially.
    verbose : bool
        Print status messages, per-voxel fit failures and a progress bar.

    Returns
    -------
    FA_img : ants.ANTsImage
        Fractional anisotropy.
    MD_img : ants.ANTsImage
        Mean diffusivity.
    RGB_img : ants.ANTsImage
        3-channel color-FA image.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * ``bvals``, ``bvecs_5d`` or ``maskin`` do not match the image
          dimensions;
        * ``bvals_to_use`` selects no volume.

    Errors while building a voxel's gradient table (e.g. non-unit b-vectors)
    propagate to the caller.
    """
    import numpy as np
    import ants
    import dipy.reconst.dti as dti
    from dipy.core.gradients import gradient_table
    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    model_params = model_params or {}
    img = imagein.numpy()
    mask = maskin.numpy().astype(bool)
    X, Y, Z, N = img.shape

    # Accept lists/tuples as documented; the boolean indexing below needs arrays.
    bvals = np.asarray(bvals).reshape(-1)
    bvecs_5d = np.asarray(bvecs_5d)
    if bvals.shape[0] != N:
        raise ValueError(
            f"bvals has {bvals.shape[0]} entries but the image has {N} volumes.")
    if bvecs_5d.shape != (X, Y, Z, N, 3):
        raise ValueError(
            f"bvecs_5d has shape {bvecs_5d.shape}; expected {(X, Y, Z, N, 3)}.")
    if mask.shape != (X, Y, Z):
        raise ValueError(
            f"maskin has shape {mask.shape}; expected {(X, Y, Z)}.")

    if bvals_to_use is not None:
        sel = np.isin(bvals, bvals_to_use)
        if not sel.any():
            raise ValueError("bvals_to_use does not match any entry of bvals.")
        img = img[..., sel]
        bvals = bvals[sel]
        bvecs_5d = bvecs_5d[..., sel, :]

    FA = np.zeros((X, Y, Z), dtype=np.float32)
    MD = np.zeros((X, Y, Z), dtype=np.float32)
    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)

    def fit_voxel(ix, iy, iz):
        # Each call writes only its own voxel of FA/MD/RGB.
        sig = img[ix, iy, iz, :]
        if np.all(sig == 0):
            return
        # bvecs passed by keyword: works across dipy versions, and recent
        # releases warn when it is passed positionally.
        gtab = gradient_table(bvals, bvecs=bvecs_5d[ix, iy, iz, :, :])
        try:
            model = dti.TensorModel(gtab, **model_params)
            fit = model.fit(sig)
            evals = fit.evals
            evecs = fit.evecs
            FA[ix, iy, iz] = fractional_anisotropy(evals)
            MD[ix, iy, iz] = mean_diffusivity(evals)
            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], evecs)
        except Exception as e:
            # Best effort: a failed voxel stays 0 instead of aborting the volume.
            if verbose:
                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")

    coords = np.argwhere(mask)
    if verbose:
        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")

    if num_threads > 1:
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            # Consuming the iterator re-raises any exception from a worker.
            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords),
                      total=len(coords), disable=not verbose))
    else:
        for c in tqdm(coords, disable=not verbose):
            fit_voxel(*c)

    ref = ants.slice_image(imagein, axis=3, idx=0)
    return (
        ants.copy_image_info(ref, ants.from_numpy(FA)),
        ants.copy_image_info(ref, ants.from_numpy(MD)),
        ants.merge_channels([ants.copy_image_info(ref, ants.from_numpy(RGB[..., i])) for i in range(3)])
    )
4145
+
4146
+
4147
def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
    """
    Generate voxel-wise b-vectors from a global bvec table and a voxel-wise rotation field.

    For every voxel ``(i, j, k)`` and direction ``n`` the output is
    ``R[i, j, k] @ global_bvecs[n]``, or ``R[i, j, k].T @ global_bvecs[n]``
    when ``transpose`` is True. No normalisation is applied.

    Parameters
    ----------
    global_bvecs : array-like of shape (N, 3)
        Global diffusion gradient directions.
    voxel_rotations : array-like of shape (X, Y, Z, 3, 3)
        3x3 rotation matrix for each voxel (can come from the Jacobian of a
        deformation field).
    transpose : bool, optional
        If True, transpose the rotation matrices before applying them to the
        b-vectors.

    Returns
    -------
    bvecs_5d : ndarray of shape (X, Y, Z, N, 3), dtype float32
        Voxel-specific b-vectors.
    """
    global_bvecs = np.asarray(global_bvecs)
    rotations = np.asarray(voxel_rotations)
    X, Y, Z, _, _ = rotations.shape
    N = global_bvecs.shape[0]
    if transpose:
        # Swapping the two trailing axes is R.T for every voxel (a view, no copy).
        rotations = np.swapaxes(rotations, -1, -2)

    bvecs_5d = np.zeros((X, Y, Z, N, 3), dtype=np.float32)
    # One batched (X, Y, Z, 3, 3) @ (3,) product per direction replaces the
    # former Python-level loop over every voxel.
    for n, bvec in enumerate(global_bvecs):
        bvecs_5d[:, :, :, n, :] = rotations @ bvec
    return bvecs_5d
4180
+
3935
4181
  def dipy_dti_recon(
3936
4182
  image,
3937
4183
  bvalsfn,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: antspymm
3
- Version: 1.5.5
3
+ Version: 1.5.7
4
4
  Summary: multi-channel/time-series medical image processing with antspyx
5
5
  Author-email: "Avants, Gosselin, Tustison, Reardon" <stnava@gmail.com>
6
6
  License: Apache-2.0
@@ -24,7 +24,9 @@ docs/deepnbm.jpg
24
24
  docs/deformation_gradient_reo.py
25
25
  docs/describe_mm_data.R
26
26
  docs/dipy_dti_recon.py
27
+ docs/dti_distortion_correction_voxelwise_varying_bvectors_example.py
27
28
  docs/dti_recon.py
29
+ docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py
28
30
  docs/dti_reg.py
29
31
  docs/dwi_rebasing.py
30
32
  docs/dwi_run.py
@@ -62,4 +64,5 @@ docs/ukbb_to_nrg_processing2.py
62
64
  docs/visualize_tractogram.py
63
65
  tests/test_loop.py
64
66
  tests/test_nrg_validation.py
65
- tests/test_reference_run.py
67
+ tests/test_reference_run.py
68
+ tests/voxelwise_bvec_dti_recon_test.py
@@ -0,0 +1,353 @@
1
+ import numpy as np
2
+ import ants
3
+ import os
4
+ from dipy.io.gradients import read_bvals_bvecs
5
+ from scipy.stats import pearsonr
6
+ import antspymm
7
+ import numpy as np
8
+ from scipy.stats import pearsonr
9
nt = 8  # worker threads passed as num_threads to antspymm.efficient_dwi_fit_voxelwise
10
+
11
+ ##################################################################
12
+ # for easier to access data with a full mm_csv example, see:
13
+ # github.com:stnava/ANTPD_antspymm
14
+ ##################################################################
15
+ from os.path import exists
16
+ import os
17
+ import signal
18
+ import urllib.request
19
+ import zipfile
20
+ import tempfile
21
+ from pathlib import Path
22
+ from tqdm import tqdm
23
+ import antspynet
24
+
25
# Paths, relative to the data root, that must all exist for a directory to be
# accepted by find_data_dir (its default ``required_files``).
REQUIRED_FILES = [
    "PPMI/101018/20210412/T1w/1496225/PPMI-101018-20210412-T1w-1496225.nii.gz",
    "PPMI/101018/20210412/DTI_LR/1496234/PPMI-101018-20210412-DTI_LR-1496234.nii.gz",
]
29
+
30
def broadcast_bvecs_voxelwise(rotated_bvecs, shape):
    """
    Repeat one (N, 3) b-vector table at every voxel of a spatial grid.

    Parameters
    ----------
    rotated_bvecs : array-like, shape (N, 3)
        b-vector table to replicate.
    shape : sequence of int
        Spatial grid shape, e.g. ``image.shape``. A tuple or a list.

    Returns
    -------
    ndarray, shape (*shape, N, 3)
        Writable copy. ``np.broadcast_to`` on its own returns a read-only view.
    """
    rotated_bvecs = np.asarray(rotated_bvecs)
    # tuple() so list shapes work too (list + tuple would raise TypeError).
    target_shape = tuple(shape) + rotated_bvecs.shape
    return np.broadcast_to(rotated_bvecs, target_shape).copy()
32
+
33
+
34
+ def _validate_required_files(base_dir, required_files):
35
+ for rel_path in required_files:
36
+ full_path = os.path.join(base_dir, rel_path)
37
+ if not os.path.isfile(full_path):
38
+ print(f"❌ Missing required file: {rel_path}")
39
+ return False
40
+ return True
41
+
42
def _download_with_progress(url, destination):
    """
    Stream *url* into the file *destination* while showing a tqdm byte counter.

    The progress-bar total comes from the Content-Length header (0 when the
    server omits it). Data are copied in 8 KiB chunks until the response is
    exhausted.
    """
    chunk_size = 8192
    with urllib.request.urlopen(url) as response, open(destination, 'wb') as sink:
        expected = int(response.getheader('Content-Length', 0))
        with tqdm(total=expected, unit='B', unit_scale=True,
                  desc="Downloading", ncols=80) as bar:
            # iter() with a b'' sentinel stops at end of stream.
            for block in iter(lambda: response.read(chunk_size), b''):
                sink.write(block)
                bar.update(len(block))
52
+
53
def find_data_dir(candidate_paths=None, max_tries=5, timeout=22, allow_download=None, required_files=REQUIRED_FILES):
    """
    Attempts to locate or download the ANTsPyMM testing dataset.

    Sources are tried in this order:

    1. ``candidate_paths``.
    2. If ``allow_download`` is a string, download and extract the figshare
       archive under it.
    3. Otherwise, prompt the user interactively, with a per-attempt timeout
       on POSIX systems.

    Parameters
    ----------
    candidate_paths : list of str or None
        Directories to search for the data. If None, uses sensible defaults.
    max_tries : int
        Number of chances to enter a valid path manually.
    timeout : int
        Seconds to wait for user input before timing out (POSIX only).
    allow_download : None | str
        If not None, will download to {allow_download}/nrgdata_test if needed.
    required_files : list of str
        Relative paths that must exist inside the data directory.

    Returns
    -------
    str
        Path to a valid data directory.

    Raises
    ------
    RuntimeError
        Raised in any of these cases:

        * the download or extraction fails, or the extracted data is missing
          files;
        * the prompt times out or is interrupted;
        * no valid directory is supplied.
    """
    if candidate_paths is None:
        candidate_paths = [
            "~/Downloads/temp/shortrun/nrgdata_test",
            "~/Downloads/ANTsPyMM_testing_data/nrgdata_test",
            "~/data/ppmi/nrgdata_test",
            "/mnt/data/nrgdata_test",
        ]

    # First, search known paths
    for path in candidate_paths:
        full_path = os.path.expanduser(path)
        if os.path.isdir(full_path) and _validate_required_files(full_path, required_files):
            print(f"✅ Found valid data directory: {full_path}")
            return full_path

    # Handle automatic download
    if isinstance(allow_download, str):
        base_dir = os.path.expanduser(allow_download)
        target_dir = os.path.join(base_dir, "nrgdata_test")
        if not os.path.isdir(target_dir) or not _validate_required_files(target_dir, required_files):
            print(f"📥 Will download data to: {target_dir}")
            url = "https://figshare.com/ndownloader/articles/29391236/versions/1"
            os.makedirs(base_dir, exist_ok=True)
            zip_path = os.path.join(tempfile.gettempdir(), "antspymm_testdata.zip")

            try:
                _download_with_progress(url, zip_path)
                print("📦 Extracting...")
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(base_dir)
                print(f"✅ Extracted to {target_dir}")
            except Exception as e:
                raise RuntimeError(f"❌ Download or extraction failed: {e}") from e
            finally:
                # The archive is only needed for extraction; remove it (best effort).
                try:
                    os.remove(zip_path)
                except OSError:
                    pass

        if not _validate_required_files(target_dir, required_files):
            raise RuntimeError(f"❌ Downloaded data is missing required files in {target_dir}")
        return target_dir

    # Timeout handler for POSIX
    def timeout_handler(signum, frame):
        raise TimeoutError("⏳ No input received in time.")

    use_alarm = os.name == 'posix'
    # Remember the previous handler so it can be restored when we are done.
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler) if use_alarm else None

    # Manual user prompt
    print("🔍 Could not find valid data. You may enter a directory manually.")
    print("🔗 Dataset info: https://figshare.com/articles/dataset/ANTsPyMM_testing_data/29391236")

    try:
        for attempt in range(1, max_tries + 1):
            try:
                if use_alarm:
                    signal.alarm(timeout)
                user_input = input(f"⏱️ Attempt {attempt}/{max_tries} — Enter data directory (or 'q' to quit): ").strip()
                if use_alarm:
                    signal.alarm(0)

                if user_input.lower() == 'q':
                    break

                path = os.path.expanduser(user_input)
                if os.path.isdir(path) and _validate_required_files(path, required_files):
                    print(f"✅ Using user-provided directory: {path}")
                    return path
                print("❌ Invalid or incomplete directory.")

            except TimeoutError as e:
                raise RuntimeError(str(e)) from e
            except KeyboardInterrupt as e:
                raise RuntimeError("User interrupted execution. Exiting.") from e
            except EOFError:
                # stdin is closed (non-interactive run): stop prompting.
                break
            finally:
                # Never leave an alarm pending once this attempt is over.
                if use_alarm:
                    signal.alarm(0)
    finally:
        # signal.signal returns None when the old handler was not set from
        # Python; in that case there is nothing we can restore.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

    raise RuntimeError("❗ No valid data directory found and download not permitted.")
148
+
149
# Locations searched (in order) for the ANTsPyMM test data before falling
# back to downloading it under ~/Downloads.
candidate_rdirs = [
    "~/Downloads/nrgdata_test/",
    "~/Downloads/temp/nrgdata_test/",
    "~/nrgdata_test/",
    "~/data/ppmi/nrgdata_test/",
    "/mnt/data/ppmi_testing/nrgdata_test/",
]

# find_data_dir can return a path without a trailing separator (e.g. the
# download target), so normalise it; later code appends to rdir.
rdir = os.path.join(find_data_dir(candidate_rdirs, allow_download="~/Downloads"), "")
print(f"Using data directory: {rdir}")

# Use the same thread count as the model fits (nt).
# NOTE(review): numpy/ants/antspynet are already imported above, so libraries
# that only read these variables at load time may ignore them here.
nthreads = str(nt)
os.environ.update(dict.fromkeys(
    (
        "TF_NUM_INTEROP_THREADS",
        "TF_NUM_INTRAOP_THREADS",
        "ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
    ),
    nthreads,
))
166
+ import numpy as np
167
+ import glob as glob
168
+ import antspymm
169
+ import ants
170
+ import random
171
+ import re
172
+
173
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a b-vector text file and rotate every vector by *rotmat*.

    The file may be laid out as 3 x N or N x 3. Any table whose first
    dimension is not 3 is transposed to 3 x N before rotating. Returns an
    (N, 3) array whose rows are ``rotmat @ b``.
    """
    table = np.loadtxt(bvec_file)
    columns = table if table.shape[0] == 3 else table.T
    return np.matmul(rotmat, columns).T
179
+
180
def mean_rgb_correlation(img1, img2, mask):
    """
    Mean Pearson correlation between corresponding channels of two
    multi-channel ANTs images, restricted to a mask.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel image (e.g. an RGB FA map). Its channels are
        obtained with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multi-channel image with the same shape and channel count.
    mask : ants.ANTsImage
        Only voxels where ``mask > 0`` are compared.

    Returns
    -------
    float
        Mean Pearson r over all channels. A channel that is constant (within
        the mask) in either image contributes 0.0.

    Raises
    ------
    ValueError
        If the images differ in shape or number of channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if len(img1c) != len(img2c):
        raise ValueError("Input images must have the same number of channels.")

    correlations = []
    for ch1, ch2 in zip(img1c, img2c):  # R, G, B for color-FA images
        x = extract_masked_values(ch1, mask)
        y = extract_masked_values(ch2, mask)
        if np.std(x) == 0 or np.std(y) == 0:
            corr = 0.0  # Pearson r is undefined for a constant channel
        else:
            corr, _ = pearsonr(x, y)
        correlations.append(corr)
    return np.mean(correlations)
210
+
211
+ import numpy as np
212
+ import ants
213
+
214
def mean_rgb_mae(img1, img2, mask):
    """
    Mean absolute error between corresponding channels of two multi-channel
    ANTs images, restricted to a mask.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel image (e.g. an RGB FA map). Its channels are
        obtained with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multi-channel image with the same shape and channel count.
    mask : ants.ANTsImage
        Binary mask; only voxels where ``mask > 0`` enter the error.

    Returns
    -------
    float
        Per-channel MAE averaged over all channels.

    Raises
    ------
    ValueError
        If the images differ in shape or number of channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if len(img1c) != len(img2c):
        raise ValueError("Input images must have the same number of channels.")

    mae_values = [
        np.mean(np.abs(extract_masked_values(ch1, mask) - extract_masked_values(ch2, mask)))
        for ch1, ch2 in zip(img1c, img2c)  # R, G, B for color-FA images
    ]
    return np.mean(mae_values)
246
+
247
def extract_masked_values(image, mask):
    """Return a 1D array of *image* voxel values at positions where *mask* > 0."""
    voxels = image.numpy()
    inside = mask.numpy() > 0
    return voxels[inside]
249
+
250
+ import numpy as np
251
+
252
def verify_unit_bvecs(bvecs, tol=1e-5):
    """
    Check which b-vectors have (approximately) unit length.

    Parameters
    ----------
    bvecs : array-like, shape (N, 3)
        One b-vector per row.
    tol : float
        Maximum allowed deviation of a norm from 1.

    Returns
    -------
    is_unit : np.ndarray of bool, shape (N,)
        True where ``|norm - 1| < tol``.
    norms : np.ndarray, shape (N,)
        Euclidean norm of each row.
    """
    lengths = np.linalg.norm(np.asarray(bvecs), axis=1)
    within_tol = np.abs(lengths - 1) < tol
    return within_tol, lengths
274
+
275
# os.path.join works whether or not rdir ends with a separator.
mydir = os.path.join(rdir, "PPMI")
outdir = re.sub('nrgdata_test', 'antspymmoutput', rdir)

t1_candidates = glob.glob(os.path.join(mydir, "101018/20210412/T1w/1496225/*.nii.gz"))
dtfn = []  # defined even when no T1 is found, so the check below cannot NameError
if t1_candidates:
    t1fn = t1_candidates[0]
    print("Begin " + t1fn)
    dtfn = sorted(glob.glob(os.path.join(mydir, "101018/20210412/DTI*/*/*.nii.gz")))
else:
    print(f"No T1w image found under {mydir}")

if dtfn:
    img_LR_in = ants.image_read(dtfn[0])
    img_LR_in_avg = ants.get_average_of_timeseries(img_LR_in)
    # Gradient sidecars share the image basename; anchor and escape the
    # extension so only the trailing ".nii.gz" is replaced.
    bvalfn = re.sub(r'nii\.gz$', 'bval', dtfn[0])
    bvecfn = re.sub(r'nii\.gz$', 'bvec', dtfn[0])
    if not exists(bvalfn) or not exists(bvecfn):
        raise RuntimeError(f"Required bval/bvec files not found: {bvalfn}, {bvecfn}")
    print(f"📁 Loading subject LR data from {bvalfn} ")
    bvals, bvecs = read_bvals_bvecs(bvalfn, bvecfn)
    bvecs = np.asarray(bvecs)

    print("📁 Loading subject T1 data...")
    t1w = ants.image_read(t1fn)
    t1w = ants.resample_image(t1w, [2, 2, 2], use_voxels=False)
    bxt = antspynet.brain_extraction(t1w, modality='t1', verbose=False).threshold_image(0.5, 1.5)

    scratch = tempfile.gettempdir()  # portable replacement for hard-coded /tmp

    # The globals() checks let an interactive session re-run this script
    # without repeating the slow registration / fitting steps.
    if "mytx" not in globals():
        dwianat = ants.slice_image(img_LR_in, idx=0, axis=3)
        mytx = ants.registration(t1w, dwianat, 'SyNCC', syn_metric='CC', syn_sampling=2, total_sigma=0.5)
        # With compose=, the combined DWI->T1 warp is written to disk; the
        # returned value is read back below as an image filename.
        mytx2 = ants.apply_transforms(t1w, img_LR_in_avg, mytx['fwdtransforms'],
                                      interpolator='linear', imagetype=0,
                                      compose=os.path.join(scratch, 'comptxDT2T1'))
        print(mytx2)

    print("🔄 now with distortion correction...")
    mydef = ants.image_read(mytx2)
    mywarp = ants.transform_from_displacement_field(mydef)
    img_w = antspymm.timeseries_transform(mywarp, img_LR_in, reference=t1w)
    print("🧠 Running warped fit...")

    if "FA_w" not in globals():
        mydefgrad = antspymm.deformation_gradient_optimized(
            mydef, to_rotation=False, to_inverse_rotation=True)
        bvecsdc = antspymm.distortion_correct_bvecs(
            bvecs, mydefgrad, t1w.direction, img_LR_in_avg.direction)
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=bxt,
            bvals=bvals,
            bvecs_5d=bvecsdc,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    if "FA_w2" not in globals():
        # Baseline: identical global b-vectors at every voxel.
        FA_w2, MD_w2, RGB_w2 = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=bxt,
            bvals=bvals,
            bvecs_5d=broadcast_bvecs_voxelwise(bvecs, t1w.shape),
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    print("📊 Comparing results...")
    maske = ants.iMath(bxt, 'ME', 3)
    fa_corr = mean_rgb_correlation(RGB_w, RGB_w2, maske)
    print(f"✅ Direction-weighted FA correlation (original vs distortion corrected): {fa_corr:.4f}")

    ants.image_write(t1w, os.path.join(scratch, 't1w.nii.gz'))
    ants.image_write(RGB_w, os.path.join(scratch, 'rgbw.nii.gz'))
@@ -0,0 +1,282 @@
1
+ import numpy as np
2
+ import ants
3
+ import os
4
+ from dipy.io.gradients import read_bvals_bvecs
5
+ from scipy.stats import pearsonr
6
+ import antspymm
7
+ import matplotlib.pyplot as plt
8
# Worker threads handed to antspymm.efficient_dwi_fit_voxelwise (num_threads=nt).
nt = 2
# Simulated head rotation, in degrees, about the x, y and z axes; these feed
# ants.contrib.Rotate3D(rotation=(degrotx, degroty, degrotz), ...).
degrotx = degroty = degrotz = 15
13
+
14
+
15
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean per-channel Pearson correlation between two
    multi-channel ANTs images inside a mask.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multi-channel (e.g. RGB colour-FA) image; it is split with
        ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multi-channel image with the same shape and channel count.
    mask : ants.ANTsImage
        Scalar image; voxels where ``mask > 0`` are compared.

    Returns
    -------
    float
        Mean Pearson correlation over all channels. A channel in which either
        image is constant inside the mask, or which has fewer than two masked
        voxels, contributes 0.0.

    Raises
    ------
    ValueError
        If the images differ in shape or in number of channels.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    img1c = ants.split_channels(img1)
    img2c = ants.split_channels(img2)
    if len(img1c) != len(img2c):
        raise ValueError("Input images must have the same number of channels.")

    # The mask is the same for every channel, so threshold it only once.
    inside = mask.numpy() > 0
    correlations = []
    for ch1, ch2 in zip(img1c, img2c):
        x = ch1.numpy()[inside]
        y = ch2.numpy()[inside]
        # Pearson r is undefined for constant data and needs >= 2 samples.
        if x.size < 2 or np.std(x) == 0 or np.std(y) == 0:
            correlations.append(0.0)
        else:
            correlations.append(pearsonr(x, y)[0])
    return float(np.mean(correlations))
45
+
46
+
47
def plot_correlation(x, y, xlabel="Image 1", ylabel="Image 2", title_prefix="Correlation", point_alpha=0.3):
    """
    Scatter-plot two paired samples against each other and report Pearson r.

    A dashed red reference line y = x, spanning the range of ``x``, is drawn.
    The axes get equal scaling and the figure is displayed with ``plt.show()``.

    Parameters
    ----------
    x, y : array-like
        Paired samples, converted with ``np.asarray``; must share one shape.
    xlabel, ylabel : str
        Axis labels.
    title_prefix : str
        Start of the plot title; `` (r = <value>)`` is appended.
    point_alpha : float
        Opacity of the scatter markers.

    Returns
    -------
    float
        Pearson correlation coefficient between ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    if xs.shape != ys.shape:
        raise ValueError("Inputs must have the same shape.")

    r = pearsonr(xs, ys)[0]
    diag = [xs.min(), xs.max()]  # endpoints of the y = x reference line

    plt.figure(figsize=(6, 6))
    plt.scatter(xs, ys, s=1, alpha=point_alpha, color='blue')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(f"{title_prefix} (r = {r:.3f})")
    plt.plot(diag, diag, 'r--', linewidth=1)
    plt.grid(True)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()
    return r
91
+
92
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a gradient-direction table from a text file and rotate it.

    The file may be stored as 3 x N (FSL style) or N x 3. Any table whose
    first dimension is not 3 is transposed to 3 x N, so a file holding
    exactly three directions is always interpreted as 3 x N.

    Parameters
    ----------
    bvec_file : str
        Path readable by ``np.loadtxt``.
    rotmat : np.ndarray
        3 x 3 matrix applied as ``rotmat @ bvecs`` to the 3 x N table.

    Returns
    -------
    np.ndarray
        Rotated directions with shape (N, 3).
    """
    raw = np.loadtxt(bvec_file)
    grad = raw if raw.shape[0] == 3 else raw.T
    return (rotmat @ grad).T
98
+
99
def broadcast_bvecs_voxelwise(rotated_bvecs, shape):
    """
    Replicate one (N, 3) gradient table into every voxel of a spatial grid.

    Parameters
    ----------
    rotated_bvecs : array-like
        Gradient directions, shape (N, 3).
    shape : sequence of int
        Spatial grid shape, e.g. ``img.shape[:3]``; a tuple or a list.

    Returns
    -------
    np.ndarray
        Writable array of shape ``(*shape, N, 3)``. It is copied out of the
        broadcast view, so voxels do not alias one another.
    """
    rotated_bvecs = np.asarray(rotated_bvecs)
    # tuple() lets list-valued shapes work too (list + tuple raises TypeError).
    target = tuple(shape) + rotated_bvecs.shape
    return np.broadcast_to(rotated_bvecs, target).copy()
101
+
102
def extract_masked_values(image, mask):
    """
    Return the values of ``image`` at voxels where ``mask`` is strictly positive.

    Both arguments must expose ``.numpy()`` returning arrays of matching shape.
    The result is the 1-D array produced by boolean indexing.
    """
    values = image.numpy()
    inside = mask.numpy() > 0
    return values[inside]
104
+
105
# Rotation-consistency check for antspymm.efficient_dwi_fit_voxelwise,
# followed by a simulated distortion-correction experiment.
#
# Kept as module-level code (rather than a test function) so the
# ``"<name>" not in globals()`` guards below can skip expensive fits and the
# registration when the script is re-run in an interactive session.
if True:
    import tempfile

    print("simple test for rotation consistency of efficient_dwi_fit_voxelwise")
    ex_path_mm = os.path.expanduser("~/.antspymm/")

    # ---- Load subject data ------------------------------------------------
    print("📁 Loading subject data...")
    lrid = os.path.join(ex_path_mm, "I1499279_Anon_20210819142214_5")
    img_LR_in = ants.image_read(lrid + '.nii.gz')
    img_LR_in_avg = ants.get_average_of_timeseries(img_LR_in)
    mask = img_LR_in_avg.get_mask()

    bvals, bvecs = read_bvals_bvecs(lrid + '.bval', lrid + '.bvec')
    bvecs = np.asarray(bvecs)
    shape = img_LR_in.shape[:3]

    # ---- Baseline fit: same gradient table in every voxel -----------------
    print("🧠 Running baseline (unrotated) fit...")
    bvecs_5d_orig = broadcast_bvecs_voxelwise(bvecs, shape)
    if "FA_orig" not in globals():
        FA_orig, MD_orig, RGB_orig = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_LR_in,
            maskin=mask,
            bvals=bvals,
            bvecs_5d=bvecs_5d_orig,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    # ---- Simulate a known rigid rotation of the whole acquisition ---------
    print("🔄 Applying known rotation...")
    rotator = ants.contrib.Rotate3D(rotation=(degrotx, degroty, degrotz), reference=img_LR_in_avg)
    rotation = rotator.transform()
    # The inverse is needed for the bvec rotation, for resampling the rotated
    # fit back, and as the registration initialiser; compute it once.
    rotationinv = rotation.invert()
    img_rotated_ref = ants.apply_ants_transform_to_image(rotation, img_LR_in_avg, reference=img_LR_in_avg)
    img_rotated = antspymm.timeseries_transform(rotation, img_LR_in, reference=img_LR_in_avg)
    mask_rotated = ants.apply_ants_transform_to_image(
        rotation, mask, reference=img_LR_in_avg, interpolation='nearestNeighbor'
    )
    # The gradient directions get the INVERSE rotation: if registering A to B
    # yields R, the bvecs are mapped with R^-1. The 3 x 3 matrix is taken from
    # the first nine transform parameters.
    rotmat = ants.get_ants_transform_parameters(rotationinv).reshape((4, 3))[:3, :3]
    bvecs_rotated = read_bvecs_rotated(lrid + '.bvec', rotmat)
    bvecs_5d_rot = broadcast_bvecs_voxelwise(bvecs_rotated, shape)

    print("🧠 Running rotated fit...")
    if "FA_rot" not in globals():
        FA_rot, MD_rot, RGB_rot = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_rotated,
            maskin=mask_rotated,
            bvals=bvals,
            bvecs_5d=bvecs_5d_rot,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )

    # Resample the rotated-space maps back onto the original grid.
    FA_rot_back = ants.apply_ants_transform_to_image(rotationinv, FA_rot, reference=img_LR_in_avg)
    MD_rot_back = ants.apply_ants_transform_to_image(rotationinv, MD_rot, reference=img_LR_in_avg)

    print("📊 Comparing results...")
    maske = ants.iMath(mask, 'ME', 2)
    # Smoothing the baseline maps simulates the double interpolation that the
    # rotated-and-back maps went through.
    FA_origs = ants.smooth_image(FA_orig, 1.25)
    MD_origs = ants.smooth_image(MD_orig, 1.25)
    fa1 = extract_masked_values(FA_origs, maske)
    fa2 = extract_masked_values(FA_rot_back, maske)
    md1 = extract_masked_values(MD_origs, maske)
    md2 = extract_masked_values(MD_rot_back, maske)

    fa_corr, _ = pearsonr(fa1, fa2)
    md_corr, _ = pearsonr(md1, md2)
    # plot_correlation(fa1, fa2)  # uncomment to inspect the FA scatter

    print(f"✅ FA correlation (original vs rotated): {fa_corr:.4f}")
    print(f"✅ MD correlation (original vs rotated): {md_corr:.4f}")

    assert fa_corr > 0.80, "FA correlation too low"
    assert md_corr > 0.80, "MD correlation too low"

    print("🎉 Test passed: model is rotation-consistent with voxelwise bvecs.")

    print("This shows that the simulation is effective.")
    print("Now use the simulated data to test the distortion correction...")

    # ---- Simulated distortion correction ----------------------------------
    # Map img_rotated and its bvecs_rotated partner back to the original space
    # as a generic distortion-correction framework would. SyN registration is
    # initialised with the known inverse rotation, then composed into a single
    # displacement field. That field drives both the resampling and the
    # voxelwise bvec reorientation.
    with tempfile.TemporaryDirectory() as tempdir:
        rotmatfile = os.path.join(tempdir, "rotation.mat")
        compositefile = os.path.join(tempdir, "composite")
        ants.write_transform(rotationinv, rotmatfile)

        # Run registration if not already done
        if "reg" not in globals():
            reg = ants.registration(FA_orig, FA_rot, 'SyN', initial_transform=rotmatfile)

        comptx = ants.apply_transforms(
            img_LR_in_avg,
            img_rotated_ref,
            reg['fwdtransforms'],
            interpolator='linear',
            compose=compositefile,
            verbose=True,
        )
        # Read the composed field while its temporary file still exists.
        mydef = ants.image_read(comptx)

    mydefgrad = antspymm.deformation_gradient_optimized(mydef, to_rotation=False, to_inverse_rotation=True)
    img_rotated_avg = ants.get_average_of_timeseries(img_rotated)
    bvecsRLw = antspymm.distortion_correct_bvecs(
        bvecs_rotated, mydefgrad, img_LR_in_avg.direction, img_rotated_avg.direction
    )
    mywarp = ants.transform_from_displacement_field(mydef)
    img_w = antspymm.timeseries_transform(mywarp, img_rotated, reference=img_rotated_avg)

    # Both masks depend only on img_w, so build them once instead of per fit.
    fitmask = ants.get_mask(ants.get_average_of_timeseries(img_w))
    compmask = ants.iMath(fitmask, 'ME', 2)

    # A TemporaryDirectory is deleted on exit, so RGB maps meant for
    # inspection are written to a directory that persists after the run.
    inspect_dir = tempfile.mkdtemp(prefix="dti_rotation_check_")
    RGB_origs = ants.smooth_image(RGB_orig, 1.0)
    rgb_orig_path = os.path.join(inspect_dir, "RGB_orig.nii.gz")
    ants.image_write(RGB_origs, rgb_orig_path)
    print(f"📝 Saved RGB_orig to: {rgb_orig_path}")

    correlations = []
    labels = ["NoBvecReo", "BvecReo"]
    for label, bv in zip(labels, [bvecs_5d_rot, bvecsRLw]):
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=fitmask,
            bvals=bvals,
            bvecs_5d=bv,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False,
        )
        rgb_warped_path = os.path.join(inspect_dir, f"RGB_warped_{label}.nii.gz")
        ants.image_write(RGB_w, rgb_warped_path)
        print(f"📝 Saved RGB_warped ({label}) to: {rgb_warped_path}")

        corr = mean_rgb_correlation(RGB_origs, RGB_w, compmask)
        print(f"{label:10s} correlation: {corr:.4f}")
        correlations.append(corr)

    # Compare the fit without (index 0) and with (index 1) bvec reorientation.
    print("\nComparison Result:")
    if correlations[1] > correlations[0]:
        print("✅ BvecReo gives higher correlation than NoBvecReo.")
    else:
        print("❌ NoBvecReo gives equal or higher correlation than BvecReo.")
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "antspymm"
7
- version = "1.5.5"
7
+ version = "1.5.7"
8
8
  description = "multi-channel/time-series medical image processing with antspyx"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.9"
@@ -0,0 +1,118 @@
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+ import antspymm
4
+
5
def test_generate_voxelwise_bvecs():
    """
    Check antspymm.generate_voxelwise_bvecs with a spatially constant rotation field.

    Every voxel gets the same 33-degree rotation about z. The output must have
    shape (X, Y, Z, N, 3), and each voxel's table must equal ``(rot @ bvecs.T).T``.
    """
    # Synthetic gradient table: the three unit axes (N = 3 directions).
    bvecs = np.array([
        [1, 0, 0],  # x-axis
        [0, 1, 0],  # y-axis
        [0, 0, 1],  # z-axis
    ])

    # Small synthetic volume grid.
    X, Y, Z = 2, 2, 2
    N = bvecs.shape[0]

    # Known rotation: 33 degrees about the z-axis, shape (3, 3).
    rot = R.from_euler('z', 33, degrees=True).as_matrix()

    # Voxelwise rotation field (X, Y, Z, 3, 3) holding the same matrix
    # everywhere; copy() gives a writable array rather than a broadcast view.
    voxel_rotations = np.broadcast_to(rot, (X, Y, Z, 3, 3)).copy()

    # Expected rotated bvecs, identical for every voxel: (N, 3).
    expected = (rot @ bvecs.T).T

    # Call the function under test
    bvecs_5d = antspymm.generate_voxelwise_bvecs(bvecs, voxel_rotations)
    assert bvecs_5d.shape == (X, Y, Z, N, 3), \
        f"Unexpected output shape {bvecs_5d.shape}, expected {(X, Y, Z, N, 3)}"

    # Check that all voxel outputs match expected result
    for i, j, k in np.ndindex(X, Y, Z):
        actual = bvecs_5d[i, j, k]
        assert np.allclose(actual, expected, atol=1e-6), \
            f"Mismatch at voxel {(i, j, k)}: {actual} vs {expected}"

    print("✅ test_generate_voxelwise_bvecs passed!")
39
+
40
# Run the test when this file is executed directly. pytest collects
# test_generate_voxelwise_bvecs on its own, so an unguarded call here would
# run it a second time at import/collection time.
if __name__ == "__main__":
    print("Running test for generate_voxelwise_bvecs...")
    test_generate_voxelwise_bvecs()
43
+
44
+
45
+ import numpy as np
46
+ import ants
47
+ from scipy.spatial.transform import Rotation as R
48
+
49
def generate_dummy_dwi_data(shape, n_volumes, seed=42):
    """
    Generate synthetic DWI-like 4D data and an all-ones brain mask.

    A private ``np.random.RandomState`` is used, so the global NumPy RNG is
    not reseeded as a side effect. With the default seed the values match
    what ``np.random.seed(42); np.random.rand(...)`` produces.

    Parameters
    ----------
    shape : sequence of int
        Spatial shape (X, Y, Z).
    n_volumes : int
        Number of diffusion volumes (length of the last axis).
    seed : int, optional
        Seed for the private random generator (default 42).

    Returns
    -------
    dwi : np.ndarray
        float32 array of shape ``(*shape, n_volumes)`` with values in [0, 1).
    mask : np.ndarray
        uint8 array of ones with shape ``shape``.
    """
    rng = np.random.RandomState(seed)
    dwi = rng.rand(*shape, n_volumes).astype(np.float32)
    mask = np.ones(shape, dtype=np.uint8)
    return dwi, mask
57
+
58
def create_voxelwise_bvecs(shape, bvecs, rotation_matrix=None):
    """
    Create a voxelwise (5D) bvecs array of shape (X, Y, Z, N, 3), optionally rotated.

    Parameters
    ----------
    shape : sequence of int
        Spatial grid shape; a tuple or a list.
    bvecs : array-like
        Gradient directions, shape (N, 3).
    rotation_matrix : array-like, optional
        3 x 3 matrix applied to every direction as ``R @ b``. If None,
        ``bvecs`` is used unchanged.

    Returns
    -------
    np.ndarray
        Writable array of shape ``(*shape, N, 3)``, copied out of the
        broadcast view, so it never aliases ``bvecs``.
    """
    bvecs = np.asarray(bvecs)
    if rotation_matrix is not None:
        rotated_bvecs = (np.asarray(rotation_matrix) @ bvecs.T).T
    else:
        rotated_bvecs = bvecs

    # tuple() lets list-valued shapes work (list + tuple raises TypeError).
    bvecs_5d = np.broadcast_to(rotated_bvecs, tuple(shape) + rotated_bvecs.shape)
    return bvecs_5d.copy()
70
+
71
def test_efficient_dwi_fit_voxelwise():
    """
    Smoke-test antspymm.efficient_dwi_fit_voxelwise on a tiny random volume.

    A 3x3x3 grid with 6 volumes (one b=0 and five b=1000) is used. The same
    45-degree z-rotated gradient table is placed in every voxel. The test
    checks the FA/MD image types and shapes, and that FA lies in [0, 1].
    """
    shape, n_vols = (3, 3, 3), 6

    dwi_data, mask_data = generate_dummy_dwi_data(shape, n_vols)

    bvals = np.array([0] + [1000] * 5)
    inv_sqrt2 = 1 / np.sqrt(2)
    bvecs = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [inv_sqrt2, inv_sqrt2, 0],
        [inv_sqrt2, 0, inv_sqrt2],
    ])
    bvecs_5d = create_voxelwise_bvecs(
        shape, bvecs, R.from_euler('z', 45, degrees=True).as_matrix()
    )

    FA_img, MD_img, RGB_img = antspymm.efficient_dwi_fit_voxelwise(
        imagein=ants.from_numpy(dwi_data),
        maskin=ants.from_numpy(mask_data),
        bvals=bvals,
        bvecs_5d=bvecs_5d,
        model_params={},
        bvals_to_use=None,
        num_threads=1,
        verbose=False,
    )

    assert isinstance(FA_img, ants.ANTsImage), "FA_img should be an ANTsImage"
    assert FA_img.shape == shape, f"FA image has shape {FA_img.shape}, expected {shape}"
    fa = FA_img.numpy()
    assert np.all((fa >= 0) & (fa <= 1)), "FA values should be in [0, 1]"
    assert MD_img.shape == shape, "MD image has incorrect shape"

    print("✅ test_efficient_dwi_fit_voxelwise passed!")
116
+
117
# Run the test when this file is executed directly. pytest collects
# test_efficient_dwi_fit_voxelwise on its own, so an unguarded call would
# run it a second time at import/collection time.
if __name__ == "__main__":
    test_efficient_dwi_fit_voxelwise()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes