antspymm 1.5.4__tar.gz → 1.5.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {antspymm-1.5.4/antspymm.egg-info → antspymm-1.5.6}/PKG-INFO +3 -5
  2. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm/__init__.py +4 -0
  3. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm/mm.py +213 -1
  4. {antspymm-1.5.4 → antspymm-1.5.6/antspymm.egg-info}/PKG-INFO +3 -5
  5. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm.egg-info/SOURCES.txt +4 -2
  6. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm.egg-info/requires.txt +2 -2
  7. antspymm-1.5.6/docs/dti_distortion_correction_voxelwise_varying_bvectors_example_WIP.py +238 -0
  8. antspymm-1.5.6/docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py +281 -0
  9. antspymm-1.5.6/docs/mm_csv_localint.py +206 -0
  10. {antspymm-1.5.4 → antspymm-1.5.6}/pyproject.toml +4 -4
  11. antspymm-1.5.6/tests/voxelwise_bvec_dti_recon_test.py +118 -0
  12. antspymm-1.5.4/LICENSE +0 -201
  13. antspymm-1.5.4/docs/mm_csv_localint.py +0 -52
  14. {antspymm-1.5.4 → antspymm-1.5.6}/MANIFEST.in +0 -0
  15. {antspymm-1.5.4 → antspymm-1.5.6}/README.md +0 -0
  16. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm.egg-info/dependency_links.txt +0 -0
  17. {antspymm-1.5.4 → antspymm-1.5.6}/antspymm.egg-info/top_level.txt +0 -0
  18. {antspymm-1.5.4 → antspymm-1.5.6}/docs/adni_rsfmri_2_nrg_conversion.py +0 -0
  19. {antspymm-1.5.4 → antspymm-1.5.6}/docs/antspymm_annotated_output_tree.pages +0 -0
  20. {antspymm-1.5.4 → antspymm-1.5.6}/docs/antspymm_annotated_output_tree.txt +0 -0
  21. {antspymm-1.5.4 → antspymm-1.5.6}/docs/antspymm_data_dictionary.csv +0 -0
  22. {antspymm-1.5.4 → antspymm-1.5.6}/docs/aslprep_perfusion_run_localint.py +0 -0
  23. {antspymm-1.5.4 → antspymm-1.5.6}/docs/bids_2_nrg.py +0 -0
  24. {antspymm-1.5.4 → antspymm-1.5.6}/docs/bids_cohort_example.py +0 -0
  25. {antspymm-1.5.4 → antspymm-1.5.6}/docs/bind_mm_wide.R +0 -0
  26. {antspymm-1.5.4 → antspymm-1.5.6}/docs/blind_qc.Rmd +0 -0
  27. {antspymm-1.5.4 → antspymm-1.5.6}/docs/blind_qc.html +0 -0
  28. {antspymm-1.5.4 → antspymm-1.5.6}/docs/blind_qc.py +0 -0
  29. {antspymm-1.5.4 → antspymm-1.5.6}/docs/convert_adni_dti_to_nrg.R +0 -0
  30. {antspymm-1.5.4 → antspymm-1.5.6}/docs/deepnbm.jpg +0 -0
  31. {antspymm-1.5.4 → antspymm-1.5.6}/docs/deformation_gradient_reo.py +0 -0
  32. {antspymm-1.5.4 → antspymm-1.5.6}/docs/describe_mm_data.R +0 -0
  33. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dipy_dti_recon.py +0 -0
  34. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dti_recon.py +0 -0
  35. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dti_reg.py +0 -0
  36. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dwi_rebasing.py +0 -0
  37. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dwi_run.py +0 -0
  38. {antspymm-1.5.4 → antspymm-1.5.6}/docs/dwi_run_ptbp_scrub.py +0 -0
  39. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ex_rsfmri_run_minimal_ptbp.py +0 -0
  40. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ex_sr.py +0 -0
  41. {antspymm-1.5.4 → antspymm-1.5.6}/docs/example_antspymm_output.csv +0 -0
  42. {antspymm-1.5.4 → antspymm-1.5.6}/docs/example_run_from_directory.py +0 -0
  43. {antspymm-1.5.4 → antspymm-1.5.6}/docs/flair_run_localint.py +0 -0
  44. {antspymm-1.5.4 → antspymm-1.5.6}/docs/joint_dti_recon_localint.py +0 -0
  45. {antspymm-1.5.4 → antspymm-1.5.6}/docs/make_dict_table.Rmd +0 -0
  46. {antspymm-1.5.4 → antspymm-1.5.6}/docs/make_dict_table.html +0 -0
  47. {antspymm-1.5.4 → antspymm-1.5.6}/docs/mm.py +0 -0
  48. {antspymm-1.5.4 → antspymm-1.5.6}/docs/mm_csv_ex_2.py +0 -0
  49. {antspymm-1.5.4 → antspymm-1.5.6}/docs/mm_nrg.py +0 -0
  50. {antspymm-1.5.4 → antspymm-1.5.6}/docs/nrg_cohort_example.py +0 -0
  51. {antspymm-1.5.4 → antspymm-1.5.6}/docs/parallel_study_aggregation_example.py +0 -0
  52. {antspymm-1.5.4 → antspymm-1.5.6}/docs/perfusion_ptbp.py +0 -0
  53. {antspymm-1.5.4 → antspymm-1.5.6}/docs/perfusion_run_nnl.py +0 -0
  54. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ppmi_step1_blind_qc.py +0 -0
  55. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ppmi_step2_outlierness.py +0 -0
  56. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ppmi_step3_mm_nrg_csv.py +0 -0
  57. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ppmi_step4_aggregate.py +0 -0
  58. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ptbp_nrg.py +0 -0
  59. {antspymm-1.5.4 → antspymm-1.5.6}/docs/roi_visualization.py +0 -0
  60. {antspymm-1.5.4 → antspymm-1.5.6}/docs/roi_visualization_ppmi.py +0 -0
  61. {antspymm-1.5.4 → antspymm-1.5.6}/docs/rsfmri_run_minimal_localint.py +0 -0
  62. {antspymm-1.5.4 → antspymm-1.5.6}/docs/run_local_integration_scripts.py +0 -0
  63. {antspymm-1.5.4 → antspymm-1.5.6}/docs/run_mm_example.sh +0 -0
  64. {antspymm-1.5.4 → antspymm-1.5.6}/docs/template_overlays.py +0 -0
  65. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ukbb_rsfmri.py +0 -0
  66. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ukbb_to_nrg_processing.py +0 -0
  67. {antspymm-1.5.4 → antspymm-1.5.6}/docs/ukbb_to_nrg_processing2.py +0 -0
  68. {antspymm-1.5.4 → antspymm-1.5.6}/docs/visualize_tractogram.py +0 -0
  69. {antspymm-1.5.4 → antspymm-1.5.6}/setup.cfg +0 -0
  70. {antspymm-1.5.4 → antspymm-1.5.6}/tests/test_loop.py +0 -0
  71. {antspymm-1.5.4 → antspymm-1.5.6}/tests/test_nrg_validation.py +0 -0
  72. {antspymm-1.5.4 → antspymm-1.5.6}/tests/test_reference_run.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: antspymm
3
- Version: 1.5.4
3
+ Version: 1.5.6
4
4
  Summary: multi-channel/time-series medical image processing with antspyx
5
5
  Author-email: "Avants, Gosselin, Tustison, Reardon" <stnava@gmail.com>
6
6
  License: Apache-2.0
@@ -9,20 +9,18 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Operating System :: OS Independent
10
10
  Requires-Python: >=3.9
11
11
  Description-Content-Type: text/markdown
12
- License-File: LICENSE
13
12
  Requires-Dist: h5py>=2.10.0
14
13
  Requires-Dist: numpy>=1.19.4
15
14
  Requires-Dist: pandas>=1.0.1
16
15
  Requires-Dist: antspyx>=0.4.2
17
- Requires-Dist: antspynet>=0.2.8
18
- Requires-Dist: antspyt1w>=0.9.3
16
+ Requires-Dist: antspynet>=0.2.9
17
+ Requires-Dist: antspyt1w>=0.9.8
19
18
  Requires-Dist: pathlib
20
19
  Requires-Dist: dipy
21
20
  Requires-Dist: nibabel
22
21
  Requires-Dist: scipy
23
22
  Requires-Dist: siq
24
23
  Requires-Dist: scikit-learn
25
- Dynamic: license-file
26
24
 
27
25
  # ANTsPyMM
28
26
 
@@ -65,6 +65,8 @@ from .mm import mc_denoise
65
65
  from .mm import mc_reg
66
66
  from .mm import dti_reg
67
67
  from .mm import timeseries_reg
68
+ from .mm import timeseries_transform
69
+ from .mm import copy_spatial_metadata_from_3d_to_4d
68
70
  from .mm import concat_dewarp
69
71
  from .mm import mc_resample_image_to_target
70
72
  from .mm import trim_dti_mask
@@ -136,5 +138,7 @@ from .mm import segment_timeseries_by_bvalue
136
138
  from .mm import shorten_pymm_names
137
139
  from .mm import pet3d_summary
138
140
  from .mm import deformation_gradient_optimized
141
+ from .mm import efficient_dwi_fit_voxelwise
142
+ from .mm import generate_voxelwise_bvecs
139
143
 
140
144
 
@@ -1794,6 +1794,89 @@ def merge_timeseries_data( img_LR, img_RL, allow_resample=True ):
1794
1794
  mimg.append( temp )
1795
1795
  return ants.list_to_ndimage( img_LR, mimg )
1796
1796
 
1797
def copy_spatial_metadata_from_3d_to_4d(spatial_img, timeseries_img):
    """
    Transfer the spatial header of a 3D image onto the first three axes of a
    4D image, keeping the 4D image's own time-axis (4th) metadata.

    Parameters
    ----------
    spatial_img : ants.ANTsImage
        3D image whose origin, spacing and 3x3 direction are copied.
    timeseries_img : ants.ANTsImage
        4D image supplying the voxel data and the 4th-axis origin, spacing
        and direction row/column.

    Returns
    -------
    ants.ANTsImage
        A new 4D image holding ``timeseries_img``'s voxels with the spatial
        header of ``spatial_img``.

    Raises
    ------
    ValueError
        If ``spatial_img`` is not 3D or ``timeseries_img`` is not 4D.
    """
    checks = (
        (spatial_img, 3, "spatial_img must be a 3D ANTsImage."),
        (timeseries_img, 4, "timeseries_img must be a 4D ANTsImage."),
    )
    for candidate, wanted_dim, message in checks:
        if candidate.dimension != wanted_dim:
            raise ValueError(message)

    # Index 3 (time) keeps the values already stored on the 4D image.
    merged_origin = [*spatial_img.origin, timeseries_img.origin[3]]
    merged_spacing = [*spatial_img.spacing, timeseries_img.spacing[3]]

    # Overwrite only the spatial 3x3 block of the 4x4 direction matrix.
    merged_direction = timeseries_img.direction.copy()
    merged_direction[:3, :3] = spatial_img.direction

    return ants.from_numpy(
        timeseries_img.numpy(),
        origin=merged_origin,
        spacing=merged_spacing,
        direction=merged_direction,
    )
1840
+
1841
def timeseries_transform(transform, image, reference, interpolation='linear'):
    """
    Apply one spatial transform to every 3D volume of a 4D time series image.

    Each volume is resampled independently onto ``reference`` and the results
    are restacked along the time axis.

    Parameters
    ----------
    transform : ants.ANTsTransform
        An in-memory transform object as accepted by
        ``ants.apply_ants_transform_to_image`` (e.g. the output of
        ``ants.transform_from_displacement_field``). This is a single transform
        object, not a path or a list of transform files.
    image : ants.ANTsImage
        4D input image with shape (X, Y, Z, T).
    reference : ants.ANTsImage
        3D image defining the output spatial grid.
    interpolation : str
        Interpolation method passed through, e.g. 'linear' or 'nearestNeighbor'.

    Returns
    -------
    ants.ANTsImage
        4D image on ``reference``'s spatial grid. The time-axis origin, spacing
        and direction are taken from ``image``.

    Raises
    ------
    ValueError
        If ``image`` is not 4D or ``reference`` is not 3D.
    """
    if image.dimension != 4:
        raise ValueError("Input image must be 4D (X, Y, Z, T).")
    # Checked up front; otherwise this would only fail after every volume had
    # been resampled (inside copy_spatial_metadata_from_3d_to_4d).
    if reference.dimension != 3:
        raise ValueError("reference must be a 3D ANTsImage.")

    transformed_volumes = [
        ants.apply_ants_transform_to_image(
            transform=transform,
            image=ants.slice_image(image, 3, t),
            reference=reference,
            interpolation=interpolation,
        ).numpy()
        for t in range(image.shape[3])
    ]
    # Stack along the time axis and convert back to an ANTsImage.
    out_image = ants.from_numpy(np.stack(transformed_volumes, axis=-1))
    # Start from the input's 4D header (keeps the time-axis metadata) ...
    out_image = ants.copy_image_info(image, out_image)
    # ... then replace the spatial part with the reference's header.
    return copy_spatial_metadata_from_3d_to_4d(reference, out_image)
1797
1880
 
1798
1881
  def timeseries_reg(
1799
1882
  image,
@@ -3932,6 +4015,135 @@ def efficient_dwi_fit(gtab, diffusion_model, imagein, maskin,
3932
4015
  return full_fit, FA_img, MD_img, RGB_img
3933
4016
 
3934
4017
 
4018
def efficient_dwi_fit_voxelwise(imagein, maskin, bvals, bvecs_5d, model_params=None,
                                bvals_to_use=None, num_threads=1, verbose=True):
    """
    Voxel-wise diffusion tensor fitting with individual b-vectors per voxel.

    For every masked voxel a ``dipy.reconst.dti.TensorModel`` is built from the
    shared b-values and that voxel's own b-vectors, so spatially varying
    gradient directions (e.g. b-vectors reoriented by a non-linear warp) can be
    honoured.

    Parameters
    ----------
    imagein : ants.ANTsImage
        4D DWI image (X, Y, Z, N).
    maskin : ants.ANTsImage
        3D mask on the same (X, Y, Z) grid; non-zero voxels are fitted.
    bvals : (N,) array-like
        b-values shared by all voxels.
    bvecs_5d : (X, Y, Z, N, 3) array-like
        Voxel-specific b-vectors.
    model_params : dict, optional
        Extra keyword arguments passed to ``dti.TensorModel``.
    bvals_to_use : list of numbers, optional
        If given, only volumes whose b-value appears in this list are used.
    num_threads : int
        Number of worker threads; values <= 1 run serially.
    verbose : bool
        Print a status line, per-voxel fit failures and a progress bar.

    Returns
    -------
    FA_img : ants.ANTsImage
        Fractional anisotropy. It is 0 where the voxel is masked out or its
        signal is all zero, and typically 0 where the fit failed.
    MD_img : ants.ANTsImage
        Mean diffusivity, with the same convention.
    RGB_img : ants.ANTsImage
        3-channel direction-encoded colour FA.

    Raises
    ------
    ValueError
        If the image is not 4D, if the mask/bvals/bvecs shapes do not match the
        image, or if ``bvals_to_use`` selects no volume.
    """
    import numpy as np
    import ants
    import dipy.reconst.dti as dti
    from dipy.core.gradients import gradient_table
    from dipy.reconst.dti import fractional_anisotropy, color_fa, mean_diffusivity
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    model_params = model_params or {}
    img = imagein.numpy()
    mask = maskin.numpy().astype(bool)
    if img.ndim != 4:
        raise ValueError("imagein must be a 4D image (X, Y, Z, N).")
    X, Y, Z, N = img.shape

    # The docstring promises array-like inputs; boolean indexing needs ndarrays.
    bvals = np.asarray(bvals)
    bvecs_5d = np.asarray(bvecs_5d)
    # Validate up front. Otherwise a mismatch only surfaces as per-voxel
    # exceptions that are swallowed below, silently producing all-zero maps.
    if mask.shape != (X, Y, Z):
        raise ValueError(
            f"maskin shape {mask.shape} does not match image grid {(X, Y, Z)}.")
    if bvals.shape != (N,):
        raise ValueError(f"bvals must have shape ({N},), got {bvals.shape}.")
    if bvecs_5d.shape != (X, Y, Z, N, 3):
        raise ValueError(
            f"bvecs_5d must have shape {(X, Y, Z, N, 3)}, got {bvecs_5d.shape}.")

    if bvals_to_use is not None:
        sel = np.isin(bvals, bvals_to_use)
        if not sel.any():
            raise ValueError("bvals_to_use does not match any entry of bvals.")
        img = img[..., sel]
        bvals = bvals[sel]
        bvecs_5d = bvecs_5d[..., sel, :]

    FA = np.zeros((X, Y, Z), dtype=np.float32)
    MD = np.zeros((X, Y, Z), dtype=np.float32)
    RGB = np.zeros((X, Y, Z, 3), dtype=np.float32)

    def fit_voxel(ix, iy, iz):
        # Each call writes only its own voxel of FA/MD/RGB.
        if not mask[ix, iy, iz]:
            return
        sig = img[ix, iy, iz, :]
        if np.all(sig == 0):
            return
        gtab = gradient_table(bvals, bvecs=bvecs_5d[ix, iy, iz, :, :])
        try:
            fit = dti.TensorModel(gtab, **model_params).fit(sig)
            evals = fit.evals
            FA[ix, iy, iz] = fractional_anisotropy(evals)
            MD[ix, iy, iz] = mean_diffusivity(evals)
            RGB[ix, iy, iz, :] = color_fa(FA[ix, iy, iz], fit.evecs)
        except Exception as e:
            # Best effort: a failed voxel keeps its zero value.
            if verbose:
                print(f"Voxel ({ix},{iy},{iz}) fit failed: {e}")

    coords = np.argwhere(mask)
    if verbose:
        print(f"[INFO] Fitting {len(coords)} voxels using {num_threads} threads...")

    if num_threads > 1:
        # NOTE(review): fitting is mostly Python/NumPy work under the GIL, so
        # the thread speed-up may be limited.
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            list(tqdm(executor.map(lambda c: fit_voxel(*c), coords),
                      total=len(coords), disable=not verbose))
    else:
        for c in tqdm(coords, disable=not verbose):
            fit_voxel(*c)

    ref = ants.slice_image(imagein, axis=3, idx=0)

    def to_ants(arr):
        # Give the 3D result the spatial header of the first DWI volume.
        return ants.copy_image_info(ref, ants.from_numpy(arr))

    return (
        to_ants(FA),
        to_ants(MD),
        ants.merge_channels([to_ants(RGB[..., i]) for i in range(3)]),
    )
4111
+
4112
+
4113
def generate_voxelwise_bvecs(global_bvecs, voxel_rotations, transpose=False):
    """
    Generate voxel-wise b-vectors from global b-vectors and a per-voxel 3x3
    matrix field.

    For every voxel (x, y, z) and direction n the output is
    ``R[x, y, z] @ global_bvecs[n]``, or ``R[x, y, z].T @ global_bvecs[n]``
    when ``transpose`` is True.

    Parameters
    ----------
    global_bvecs : array-like of shape (N, 3)
        Global diffusion gradient directions.
    voxel_rotations : array-like of shape (X, Y, Z, 3, 3)
        3x3 matrix for each voxel (e.g. derived from the Jacobian of a
        deformation field). The output is only unit-norm if these are pure
        rotations; dipy's ``gradient_table`` may reject non-unit b-vectors.
    transpose : bool, optional
        If True, transpose each matrix before applying it to the b-vectors.

    Returns
    -------
    bvecs_5d : ndarray of shape (X, Y, Z, N, 3), dtype float32
        Voxel-specific b-vectors.

    Raises
    ------
    ValueError
        If the inputs do not have the shapes described above.
    """
    global_bvecs = np.asarray(global_bvecs)
    voxel_rotations = np.asarray(voxel_rotations)
    if voxel_rotations.ndim != 5 or voxel_rotations.shape[-2:] != (3, 3):
        raise ValueError("voxel_rotations must have shape (X, Y, Z, 3, 3).")
    if global_bvecs.ndim != 2 or global_bvecs.shape[1] != 3:
        raise ValueError("global_bvecs must have shape (N, 3).")

    if transpose:
        voxel_rotations = np.swapaxes(voxel_rotations, -1, -2)
    # Vectorised form of the per-voxel, per-direction product:
    # out[x, y, z, n, i] = sum_j R[x, y, z, i, j] * b[n, j]
    return np.einsum('xyzij,nj->xyzni', voxel_rotations, global_bvecs).astype(np.float32)
4146
+
3935
4147
  def dipy_dti_recon(
3936
4148
  image,
3937
4149
  bvalsfn,
@@ -7434,7 +7646,7 @@ def mm(
7434
7646
  if do_kk:
7435
7647
  if verbose:
7436
7648
  print('kk')
7437
- output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( hier['brain_n4_dnz'],
7649
+ output_dict['kk'] = antspyt1w.kelly_kapowski_thickness( t1atropos,
7438
7650
  labels=hier['dkt_parc']['dkt_cortex'], iterations=45 )
7439
7651
  if perfusion_image is not None:
7440
7652
  if perfusion_image.shape[3] > 1: # FIXME - better heuristic?
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: antspymm
3
- Version: 1.5.4
3
+ Version: 1.5.6
4
4
  Summary: multi-channel/time-series medical image processing with antspyx
5
5
  Author-email: "Avants, Gosselin, Tustison, Reardon" <stnava@gmail.com>
6
6
  License: Apache-2.0
@@ -9,20 +9,18 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Operating System :: OS Independent
10
10
  Requires-Python: >=3.9
11
11
  Description-Content-Type: text/markdown
12
- License-File: LICENSE
13
12
  Requires-Dist: h5py>=2.10.0
14
13
  Requires-Dist: numpy>=1.19.4
15
14
  Requires-Dist: pandas>=1.0.1
16
15
  Requires-Dist: antspyx>=0.4.2
17
- Requires-Dist: antspynet>=0.2.8
18
- Requires-Dist: antspyt1w>=0.9.3
16
+ Requires-Dist: antspynet>=0.2.9
17
+ Requires-Dist: antspyt1w>=0.9.8
19
18
  Requires-Dist: pathlib
20
19
  Requires-Dist: dipy
21
20
  Requires-Dist: nibabel
22
21
  Requires-Dist: scipy
23
22
  Requires-Dist: siq
24
23
  Requires-Dist: scikit-learn
25
- Dynamic: license-file
26
24
 
27
25
  # ANTsPyMM
28
26
 
@@ -1,4 +1,3 @@
1
- LICENSE
2
1
  MANIFEST.in
3
2
  README.md
4
3
  pyproject.toml
@@ -25,7 +24,9 @@ docs/deepnbm.jpg
25
24
  docs/deformation_gradient_reo.py
26
25
  docs/describe_mm_data.R
27
26
  docs/dipy_dti_recon.py
27
+ docs/dti_distortion_correction_voxelwise_varying_bvectors_example_WIP.py
28
28
  docs/dti_recon.py
29
+ docs/dti_reconstruction_voxelwise_varying_bvectors_localint.py
29
30
  docs/dti_reg.py
30
31
  docs/dwi_rebasing.py
31
32
  docs/dwi_run.py
@@ -63,4 +64,5 @@ docs/ukbb_to_nrg_processing2.py
63
64
  docs/visualize_tractogram.py
64
65
  tests/test_loop.py
65
66
  tests/test_nrg_validation.py
66
- tests/test_reference_run.py
67
+ tests/test_reference_run.py
68
+ tests/voxelwise_bvec_dti_recon_test.py
@@ -2,8 +2,8 @@ h5py>=2.10.0
2
2
  numpy>=1.19.4
3
3
  pandas>=1.0.1
4
4
  antspyx>=0.4.2
5
- antspynet>=0.2.8
6
- antspyt1w>=0.9.3
5
+ antspynet>=0.2.9
6
+ antspyt1w>=0.9.8
7
7
  pathlib
8
8
  dipy
9
9
  nibabel
@@ -0,0 +1,238 @@
1
+ import numpy as np
2
+ import ants
3
+ import os
4
+ from dipy.io.gradients import read_bvals_bvecs
5
+ from scipy.stats import pearsonr
6
+ import antspymm
7
nt = 2  # worker threads, passed as num_threads to efficient_dwi_fit_voxelwise below
8
+
9
+ import numpy as np
10
+ from scipy.stats import pearsonr
11
+
12
def read_bvecs_rotated(bvec_file, rotmat):
    """
    Load a b-vector text file and rotate every direction by ``rotmat``.

    The file may hold either 3 rows x N columns or N rows x 3 columns. Any
    array whose first dimension is not 3 is transposed to 3 x N first.
    NOTE(review): a file with exactly 3 directions stored as rows is ambiguous
    and is read as 3 x N.

    Parameters
    ----------
    bvec_file : str
        Path to a whitespace-delimited b-vector file.
    rotmat : ndarray of shape (3, 3)
        Rotation applied to each b-vector.

    Returns
    -------
    ndarray of shape (N, 3)
        Rotated b-vectors, one per row.
    """
    raw = np.loadtxt(bvec_file)
    vectors_as_columns = raw if raw.shape[0] == 3 else raw.T
    return (rotmat @ vectors_as_columns).T
18
+
19
def mean_rgb_correlation(img1, img2, mask):
    """
    Compute the mean Pearson correlation between corresponding channels of two
    multichannel (e.g. RGB FA) ANTs images, restricted to a mask.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multichannel image; split with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multichannel image with the same shape and channel count.
    mask : ants.ANTsImage
        Voxels where the mask is > 0 are compared.

    Returns
    -------
    float
        Mean correlation across all channels. A channel that is constant in
        either image contributes 0.0.

    Raises
    ------
    ValueError
        If the shapes or channel counts differ, or the mask selects no voxel.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    channels1 = ants.split_channels(img1)
    channels2 = ants.split_channels(img2)
    if len(channels1) != len(channels2):
        raise ValueError("Input images must have the same number of channels.")
    correlations = []
    for chan1, chan2 in zip(channels1, channels2):
        x = extract_masked_values(chan1, mask)
        y = extract_masked_values(chan2, mask)
        if x.size == 0:
            raise ValueError("mask selects no voxels.")
        if np.std(x) == 0 or np.std(y) == 0:
            corr = 0.0  # correlation is undefined for flat channels
        else:
            corr, _ = pearsonr(x, y)
        correlations.append(corr)
    return np.mean(correlations)
49
+
50
+ import numpy as np
51
+ import ants
52
+
53
def mean_rgb_mae(img1, img2, mask):
    """
    Compute the mean absolute error (MAE) between corresponding channels of two
    multichannel (e.g. RGB FA) ANTs images, restricted to a mask.

    Parameters
    ----------
    img1 : ants.ANTsImage
        First multichannel image; split with ``ants.split_channels``.
    img2 : ants.ANTsImage
        Second multichannel image with the same shape and channel count.
    mask : ants.ANTsImage
        Voxels where the mask is > 0 are compared.

    Returns
    -------
    float
        Per-channel MAE averaged over all channels.

    Raises
    ------
    ValueError
        If the shapes or channel counts differ, or the mask selects no voxel
        (which would otherwise produce NaN).
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same shape.")
    channels1 = ants.split_channels(img1)
    channels2 = ants.split_channels(img2)
    if len(channels1) != len(channels2):
        raise ValueError("Input images must have the same number of channels.")

    mae_values = []
    for chan1, chan2 in zip(channels1, channels2):
        x = extract_masked_values(chan1, mask)
        y = extract_masked_values(chan2, mask)
        if x.size == 0:
            raise ValueError("mask selects no voxels.")
        mae_values.append(np.mean(np.abs(x - y)))

    return np.mean(mae_values)
85
+
86
def extract_masked_values(image, mask):
    """
    Return the voxel values of ``image`` where ``mask`` is strictly positive.

    Both arguments only need a ``.numpy()`` method returning arrays of the same
    shape (e.g. ``ants.ANTsImage``). The result is a 1D array in C order.
    """
    voxels = image.numpy()
    inside = mask.numpy() > 0
    return voxels[inside]
88
+
89
+ import numpy as np
90
+
91
def verify_unit_bvecs(bvecs, tol=1e-5):
    """
    Check which b-vectors have unit Euclidean length.

    Parameters
    ----------
    bvecs : array-like, shape (N, 3)
        One b-vector per row.
    tol : float
        Maximum allowed absolute deviation of the norm from 1.

    Returns
    -------
    is_unit : np.ndarray of bool, shape (N,)
        True where ``|norm - 1| < tol``. Zero (b0) vectors are therefore False.
    norms : np.ndarray, shape (N,)
        Euclidean norm of each row.
    """
    as_array = np.asarray(bvecs)
    norms = np.linalg.norm(as_array, axis=1)
    deviation = np.abs(norms - 1)
    return deviation < tol, norms
113
+
114
# Example: distortion-correction consistency of
# antspymm.efficient_dwi_fit_voxelwise (WIP; originally sketched as
# test_efficient_dwi_fit_voxelwise_distortion_correction()).
#
# Each expensive step is skipped when its result already exists in the module
# globals, so the script can be re-run in a shared interactive namespace
# without recomputing the fits.
if __name__ == "__main__":
    print("simple test for distortion_correction consistency of efficient_dwi_fit_voxelwise")
    ex_path_mm = os.path.expanduser("~/.antspymm/")
    #### Load in data ####
    print("Load in subject data ...")
    lrid = "I1499279_Anon_20210819142214_5"
    rlid = "I1499337_Anon_20210819142214_6"

    print("📁 Loading subject LR data...")
    lrid = os.path.join(ex_path_mm, lrid)
    img_LR_in = ants.image_read(lrid + '.nii.gz')
    img_LR_in_avg = ants.get_average_of_timeseries(img_LR_in)
    mask = img_LR_in_avg.get_mask()
    bvals, bvecs = read_bvals_bvecs(lrid + '.bval', lrid + '.bvec')
    bvecs = np.asarray(bvecs)
    shape = img_LR_in.shape[:3]

    print("📁 Loading subject RL data...")
    rlid = os.path.join(ex_path_mm, rlid)
    img_RL_in = ants.image_read(rlid + '.nii.gz')
    bvalsRL, bvecsRL = read_bvals_bvecs(rlid + '.bval', rlid + '.bvec')
    bvecsRL = np.asarray(bvecsRL)
    shapeRL = img_RL_in.shape[:3]

    img_RL_in_avg = ants.get_average_of_timeseries(img_RL_in)
    maskRL = img_RL_in_avg.get_mask()

    print("🧠 Running baseline LR fit...")
    # Spatially constant b-vectors: every voxel gets the same gradient table.
    bvecs_5d_orig = np.broadcast_to(bvecs, shape + bvecs.shape).copy()
    if "FA_orig" not in globals():
        FA_orig, MD_orig, RGB_orig = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_LR_in,
            maskin=mask,
            bvals=bvals,
            bvecs_5d=bvecs_5d_orig,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    # The native RL fit must use the RL image's own grid, not the LR shape.
    bvecs_5d_origRL = np.broadcast_to(bvecsRL, shapeRL + bvecsRL.shape).copy()
    if "FA_origRL" not in globals():
        FA_origRL, MD_origRL, RGB_origRL = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_RL_in,
            maskin=maskRL,
            bvals=bvalsRL,
            bvecs_5d=bvecs_5d_origRL,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    print("dist corr")
    if "mytx" not in globals():
        mytx = ants.registration(FA_orig, FA_origRL, 'SyNBold')
        # With compose=..., the result is the path of the composed warp,
        # which is read back below.
        mytx2 = ants.apply_transforms(FA_orig, FA_origRL, mytx['fwdtransforms'],
                                      interpolator='linear', imagetype=0, compose='/tmp/comptx')
        print(mytx2)

    print("🔄 now with distortion correction...")
    mydef = ants.image_read(mytx2)
    mywarp = ants.transform_from_displacement_field(mydef)
    # Both the warped series and the warped mask live on the LR grid.
    img_w = antspymm.timeseries_transform(mywarp, img_RL_in, reference=img_LR_in_avg)
    mask_w = ants.apply_ants_transform_to_image(mywarp, maskRL, reference=img_LR_in_avg, interpolation='nearestNeighbor')
    print("🧠 Running warped fit...")
    if "FA_w" not in globals():
        # Per-voxel b-vectors, presumably reoriented by the warp's local rotation.
        mydefgrad = antspymm.deformation_gradient_optimized(
            mydef, to_rotation=False, to_inverse_rotation=True)
        bvecsRLw = antspymm.generate_voxelwise_bvecs(bvecsRL, mydefgrad, transpose=False)
        FA_w, MD_w, RGB_w = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=mask_w,
            bvals=bvalsRL,
            bvecs_5d=bvecsRLw,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    if "FA_w2" not in globals():
        # Same warped data, but with un-reoriented RL b-vectors broadcast to
        # the LR grid that img_w lives on.
        bvecs_5d_origRL_on_LR = np.broadcast_to(bvecsRL, shape + bvecsRL.shape).copy()
        FA_w2, MD_w2, RGB_w2 = antspymm.efficient_dwi_fit_voxelwise(
            imagein=img_w,
            maskin=mask_w,
            bvals=bvalsRL,
            bvecs_5d=bvecs_5d_origRL_on_LR,
            model_params={},
            bvals_to_use=None,
            num_threads=nt,
            verbose=False
        )

    print("📊 Comparing results...")
    # Keep voxels inside both masks (their sum equals 2), then erode
    # (iMath 'ME', radius 3).
    maskJoined = ants.threshold_image(mask + mask_w, 1.05, 2.0)
    maske = ants.iMath(maskJoined, 'ME', 3)
    rgb_corr = mean_rgb_correlation(RGB_orig, RGB_w, maske)
    print(f"✅ RGB correlation (original vs distortion corrected): {rgb_corr:.4f}")

    rgb_corrX = mean_rgb_correlation(RGB_orig, RGB_w2, maske)
    print(f"✅ RGB correlation (original vs distortion corrected global recon): {rgb_corrX:.4f}")

    RGB_origRLc = [
        ants.apply_ants_transform_to_image(
            mywarp, channel,
            reference=img_LR_in_avg, interpolation='linear'
        )
        for channel in ants.split_channels(RGB_origRL)
    ]

    rgb_corrY = mean_rgb_correlation(RGB_orig, ants.merge_channels(RGB_origRLc), maske)
    print(f"✅ RGB correlation (original vs warped RGB global recon): {rgb_corrY:.4f}")

    # No pass/fail threshold is enforced yet (WIP), e.g.:
    # assert rgb_corr > 0.80, "RGB correlation too low"
    print("🎉 Done: compare the correlations above (no pass/fail threshold enforced).")

    # ants.image_write( FA_orig, '/tmp/xxx.nii.gz' )
    # ants.image_write( FA_w, '/tmp/yyy.nii.gz' )