nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/cast_volume.py +9 -1
  4. nabu/app/cli_configs.py +80 -3
  5. nabu/app/estimate_motion.py +54 -0
  6. nabu/app/multicor.py +2 -4
  7. nabu/app/pcaflats.py +116 -0
  8. nabu/app/reconstruct.py +1 -7
  9. nabu/app/reduce_dark_flat.py +5 -2
  10. nabu/estimation/cor.py +1 -1
  11. nabu/estimation/motion.py +557 -0
  12. nabu/estimation/tests/test_motion_estimation.py +471 -0
  13. nabu/estimation/tilt.py +1 -1
  14. nabu/estimation/translation.py +47 -1
  15. nabu/io/cast_volume.py +100 -13
  16. nabu/io/reader.py +32 -1
  17. nabu/io/tests/test_remove_volume.py +152 -0
  18. nabu/pipeline/config_validators.py +42 -43
  19. nabu/pipeline/estimators.py +255 -0
  20. nabu/pipeline/fullfield/chunked.py +67 -43
  21. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  22. nabu/pipeline/fullfield/nabu_config.py +20 -14
  23. nabu/pipeline/fullfield/processconfig.py +17 -3
  24. nabu/pipeline/fullfield/reconstruction.py +4 -1
  25. nabu/pipeline/params.py +12 -0
  26. nabu/pipeline/tests/test_estimators.py +240 -3
  27. nabu/preproc/ccd.py +53 -3
  28. nabu/preproc/flatfield.py +306 -1
  29. nabu/preproc/shift.py +3 -1
  30. nabu/preproc/tests/test_pcaflats.py +154 -0
  31. nabu/processing/rotation_cuda.py +3 -1
  32. nabu/processing/tests/test_rotation.py +4 -2
  33. nabu/reconstruction/astra.py +245 -0
  34. nabu/reconstruction/fbp.py +7 -0
  35. nabu/reconstruction/fbp_base.py +31 -7
  36. nabu/reconstruction/fbp_opencl.py +8 -0
  37. nabu/reconstruction/filtering_opencl.py +2 -0
  38. nabu/reconstruction/mlem.py +47 -13
  39. nabu/reconstruction/tests/test_filtering.py +13 -2
  40. nabu/reconstruction/tests/test_mlem.py +91 -62
  41. nabu/resources/dataset_analyzer.py +144 -20
  42. nabu/resources/nxflatfield.py +101 -35
  43. nabu/resources/tests/test_nxflatfield.py +1 -1
  44. nabu/resources/utils.py +16 -10
  45. nabu/stitching/alignment.py +7 -7
  46. nabu/stitching/config.py +22 -20
  47. nabu/stitching/definitions.py +2 -2
  48. nabu/stitching/overlap.py +4 -4
  49. nabu/stitching/sample_normalization.py +5 -5
  50. nabu/stitching/stitcher/post_processing.py +5 -3
  51. nabu/stitching/stitcher/pre_processing.py +24 -20
  52. nabu/stitching/tests/test_config.py +3 -3
  53. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  54. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  55. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  56. nabu/stitching/utils/utils.py +7 -7
  57. nabu/testutils.py +1 -4
  58. nabu/utils.py +13 -0
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/METADATA +3 -4
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/RECORD +64 -57
  61. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/WHEEL +1 -1
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/entry_points.txt +2 -1
  63. nabu/app/correct_rot.py +0 -62
  64. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/licenses/LICENSE +0 -0
  65. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/top_level.txt +0 -0
nabu/preproc/shift.py CHANGED
@@ -51,7 +51,7 @@ class VerticalShift:
         assert np.max(iangles) < len(self.interp_infos)
         assert len(iangles) == radios.shape[0]
 
-    def apply_vertical_shifts(self, radios, iangles, output=None):
+    def apply_vertical_shifts(self, radios, iangles=None, output=None):
         """
         Parameters
         ----------
@@ -65,6 +65,8 @@ class VerticalShift:
             If given, it will be modified to contain the shifted radios.
            Must be of the same shape of `radios`.
        """
+        if iangles is None:
+            iangles = np.arange(radios.shape[0])
        self._check(radios, iangles)
 
        newradio = np.zeros_like(radios[0])
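The hunk above makes iangles optional: when omitted, one shift is applied per projection, in order. A minimal usage sketch under that assumption; the VerticalShift constructor arguments (radios shape plus one vertical shift per projection) are not shown in this diff and are assumed here:

import numpy as np
from nabu.preproc.shift import VerticalShift

radios = np.random.rand(180, 256, 320).astype(np.float32)  # (n_angles, n_z, n_x), illustrative
shifts = np.linspace(0.0, 5.0, 180)  # hypothetical vertical drift, one value per projection

vshift = VerticalShift(radios.shape, shifts)  # assumed signature: (radios_shape, shifts)
output = np.zeros_like(radios)
# New in this release: iangles may be omitted, defaulting to np.arange(radios.shape[0])
vshift.apply_vertical_shifts(radios, output=output)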
nabu/preproc/tests/test_pcaflats.py ADDED
@@ -0,0 +1,154 @@
+import os
+import numpy as np
+import pytest
+
+import h5py
+from nabu.testutils import utilstest
+from nabu.preproc.flatfield import (
+    PCAFlatsDecomposer,
+    PCAFlatsNormalizer,
+)
+
+
+@pytest.fixture(scope="class")
+def bootstrap_pcaflats(request):
+    cls = request.cls
+    # TODO: these tolerances for having the tests passed should be tighter.
+    # Discrepancies between id11 code and nabu code are still mysterious.
+    cls.mean_abs_tol = 1e-1
+    cls.comps_abs_tol = 1e-2
+    cls.projs, cls.flats, cls.darks = get_pcaflats_data("test_pcaflats.npz")
+    cls.raw_projs = cls.projs.copy()  # Needed because flat correction is done inplace.
+    ref_data = get_pcaflats_refdata("ref_pcaflats.npz")
+    cls.mean = ref_data["mean"]
+    cls.components_3 = ref_data["components_3"]
+    cls.components_15 = ref_data["components_15"]
+    cls.dark = ref_data["dark"]
+    cls.normalized_projs_3 = ref_data["normalized_projs_3"]
+    cls.normalized_projs_15 = ref_data["normalized_projs_15"]
+    cls.normalized_projs_custom_mask = ref_data["normalized_projs_custom_mask"]
+    cls.test_normalize_projs_custom_prop = ref_data["normalized_projs_custom_prop"]
+
+    cls.h5_filename_3 = get_h5_pcaflats("pcaflat_3.h5")
+    cls.h5_filename_15 = get_h5_pcaflats("pcaflat_15.h5")
+
+
+def get_pcaflats_data(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+    data = np.load(dataset_downloaded_path)
+    projs = data["projs"].astype(np.float32)
+    flats = data["flats"].astype(np.float32)
+    darks = data["darks"].astype(np.float32)
+
+    return projs, flats, darks
+
+
+def get_h5_pcaflats(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+
+    return dataset_downloaded_path
+
+
+def get_pcaflats_refdata(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+    data = np.load(dataset_downloaded_path)
+
+    return data
+
+
+def get_decomposition(filename):
+    with h5py.File(filename, "r") as f:
+        # Load the dataset
+        p_comps = f["entry0000/p_components"][()]
+        p_mean = f["entry0000/p_mean"][()]
+        dark = f["entry0000/dark"][()]
+    return p_comps, p_mean, dark
+
+
+@pytest.mark.usefixtures("bootstrap_pcaflats")
+class TestPCAFlatsDecomposer:
+    def test_decompose_flats(self):
+        # Build 3-sigma basis
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+        message = f"Found a discrepency between computed mean flat and reference."
+        assert np.allclose(self.mean, pca.mean, atol=self.mean_abs_tol), message
+        message = f"Found a discrepency between computed components and reference ones if nsigma=3."
+        assert np.allclose(self.components_3, np.array(pca.components), atol=self.comps_abs_tol), message
+
+        # Build 1.5-sigma basis
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=1.5)
+        message = f"Found a discrepency between computed components and reference ones, if nsigma=1.5."
+        assert np.allclose(self.components_15, np.array(pca.components), atol=self.comps_abs_tol), message
+
+    def test_save_load_decomposition(self):
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+        tmp_path = os.path.join(os.path.dirname(self.h5_filename_3), "PCA_Flats.h5")
+        pca.save_decomposition(path=tmp_path)
+        p_comps, p_mean, dark = get_decomposition(tmp_path)
+        message = f"Found a discrepency between saved and loaded mean flat."
+        assert np.allclose(self.mean, p_mean, atol=self.mean_abs_tol), message
+        message = f"Found a discrepency between saved and loaded components if nsigma=3."
+        assert np.allclose(self.components_3, p_comps, atol=self.comps_abs_tol), message
+        message = f"Found a discrepency between saved and loaded dark."
+        assert np.allclose(self.dark, dark, atol=self.comps_abs_tol), message
+        # Clean up
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+
+
+@pytest.mark.usefixtures("bootstrap_pcaflats")
+class TestPCAFlatsNormalizer:
+    def test_load_pcaflats(self):
+        """Tests that the structure of the output PCAFlat h5 file is correct."""
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        # Check the shape of the loaded data
+        assert p_comps.shape[1:] == p_mean.shape
+        assert p_comps.shape[1:] == dark.shape
+
+    def test_normalize_projs(self):
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs)
+        assert np.allclose(projs, self.normalized_projs_3, atol=1e-2)
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_15)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs)
+        assert np.allclose(projs, self.normalized_projs_15, atol=1e-2)
+
+    def test_use_custom_mask(self):
+        mask = np.zeros(self.mean.shape, dtype=bool)
+        mask[:, :10] = True
+        mask[:, -10:] = True
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs, mask=mask)
+        assert np.allclose(projs, self.normalized_projs_custom_mask, atol=1e-2)
+
+    def test_change_mask_prop(self):
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs, prop=0.05)
+        assert np.allclose(projs, self.test_normalize_projs_custom_prop, atol=1e-2)
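Taken together, the new tests exercise a two-step PCA flat-field workflow: build a principal-component basis from flats and darks, persist it to HDF5, then reload it to normalize projections in place. A condensed sketch of that flow, using only the calls and HDF5 paths that appear in the test above (array shapes and file names are illustrative):

import numpy as np
import h5py
from nabu.preproc.flatfield import PCAFlatsDecomposer, PCAFlatsNormalizer

flats = np.random.rand(50, 256, 320).astype(np.float32)   # illustrative stacks
darks = np.random.rand(5, 256, 320).astype(np.float32)
projs = np.random.rand(100, 256, 320).astype(np.float32)

# 1. Decompose the flats into a mean flat plus principal components.
#    nsigma controls which components are kept (cf. the 3-sigma and 1.5-sigma bases above).
pca = PCAFlatsDecomposer(flats, darks, nsigma=3)
pca.save_decomposition(path="PCA_Flats.h5")  # writes entry0000/{p_components, p_mean, dark}

# 2. Later, or in another process: reload the basis and normalize projections (in-place).
with h5py.File("PCA_Flats.h5", "r") as f:
    p_comps = f["entry0000/p_components"][()]
    p_mean = f["entry0000/p_mean"][()]
    dark = f["entry0000/dark"][()]
normalizer = PCAFlatsNormalizer(p_comps, dark, p_mean)
normalizer.normalize_radios(projs)                     # default mask
# normalizer.normalize_radios(projs, mask=custom_mask) # or a custom boolean mask
# normalizer.normalize_radios(projs, prop=0.05)        # or a different mask proportion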
nabu/processing/rotation_cuda.py CHANGED
@@ -1,7 +1,7 @@
 import numpy as np
 from .rotation import Rotation
 from ..utils import get_cuda_srcfile, updiv
-from ..cuda.utils import __has_pycuda__, copy_array
+from ..cuda.utils import __has_pycuda__, copy_array, check_textures_availability
 from ..cuda.processing import CudaProcessing
 
 if __has_pycuda__:
@@ -11,6 +11,8 @@ if __has_pycuda__:
 
 class CudaRotation(Rotation):
     def __init__(self, shape, angle, center=None, mode="edge", reshape=False, cuda_options=None, **sk_kwargs):
+        if not (check_textures_availability()):
+            raise RuntimeError("Need cuda textures for this class")
         if center is None:
             center = ((shape[1] - 1) / 2.0, (shape[0] - 1) / 2.0)
         super().__init__(shape, angle, center=center, mode=mode, reshape=reshape, **sk_kwargs)
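Since CudaRotation now raises when CUDA textures are unavailable, callers can probe support first and fall back to the CPU implementation. A small sketch under the assumption that the scikit-image-backed Rotation class in nabu.processing.rotation is callable on an image the same way the tests call it:

import numpy as np
from nabu.cuda.utils import __has_pycuda__, check_textures_availability

image = np.random.rand(512, 512).astype(np.float32)

if __has_pycuda__ and check_textures_availability():
    from nabu.processing.rotation_cuda import CudaRotation
    rotation = CudaRotation(image.shape, angle=15.0, mode="edge")
else:
    from nabu.processing.rotation import Rotation  # CPU (scikit-image) fallback
    rotation = Rotation(image.shape, angle=15.0, mode="edge")

rotated = rotation(image)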
nabu/processing/tests/test_rotation.py CHANGED
@@ -3,7 +3,7 @@ import pytest
 from nabu.testutils import generate_tests_scenarios
 from nabu.processing.rotation_cuda import Rotation
 from nabu.processing.rotation import __have__skimage__
-from nabu.cuda.utils import __has_pycuda__, get_cuda_context
+from nabu.cuda.utils import __has_pycuda__, get_cuda_context, check_textures_availability
 
 if __have__skimage__:
     from skimage.transform import rotate
@@ -68,7 +68,9 @@ class TestRotation:
         res = R(self.image)
         self._check_result(res, config, 1e-6)
 
-    @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda rotation")
+    @pytest.mark.skipif(
+        not (__has_pycuda__) or not (check_textures_availability()), reason="Need cuda rotation (and textures)"
+    )
     @pytest.mark.parametrize("config", scenarios)
     def test_cuda_rotation(self, config):
         R = CudaRotation(
nabu/reconstruction/astra.py ADDED
@@ -0,0 +1,245 @@
+# ruff: noqa
+try:
+    import astra
+
+    __have_astra__ = True
+except ImportError:
+    __have_astra__ = False
+    astra = None
+
+
+class AstraReconstructor:
+    """
+    Base class for reconstructors based on the Astra toolbox
+    """
+
+    default_extra_options = {
+        "axis_correction": None,
+        "clip_outer_circle": False,
+        "scale_factor": None,
+        "filter_cutoff": 1.0,
+        "outer_circle_value": 0.0,
+    }
+
+    def __init__(
+        self,
+        sinos_shape,
+        angles=None,
+        volume_shape=None,
+        rot_center=None,
+        pixel_size=None,
+        padding_mode="zeros",
+        filter_name=None,
+        slice_roi=None,
+        cuda_options=None,
+        extra_options=None,
+    ):
+        self._configure_extra_options(extra_options)
+        self._init_cuda(cuda_options)
+        self._set_sino_shape(sinos_shape)
+        self._orig_prog_geom = None
+        self._init_geometry(
+            source_origin_dist,
+            origin_detector_dist,
+            pixel_size,
+            angles,
+            volume_shape,
+            rot_center,
+            relative_z_position,
+            slice_roi,
+        )
+        self._init_fdk(padding_mode, filter_name)
+        self._alg_id = None
+        self._vol_id = None
+        self._proj_id = None
+
+    def _configure_extra_options(self, extra_options):
+        self.extra_options = self.default_extra_options.copy()
+        self.extra_options.update(extra_options or {})
+
+    def _init_cuda(self, cuda_options):
+        cuda_options = cuda_options or {}
+        self.cuda = CudaProcessing(**cuda_options)
+
+    def _set_sino_shape(self, sinos_shape):
+        if len(sinos_shape) != 3:
+            raise ValueError("Expected a 3D shape")
+        self.sinos_shape = sinos_shape
+        self.n_sinos, self.n_angles, self.prj_width = sinos_shape
+
+    def _set_pixel_size(self, pixel_size):
+        if pixel_size is None:
+            det_spacing_y = det_spacing_x = 1
+        elif np.iterable(pixel_size):
+            det_spacing_y, det_spacing_x = pixel_size
+        else:
+            # assuming scalar
+            det_spacing_y = det_spacing_x = pixel_size
+        self._det_spacing_y = det_spacing_y
+        self._det_spacing_x = det_spacing_x
+
+    def _set_slice_roi(self, slice_roi):
+        self.slice_roi = slice_roi
+        self._vol_geom_n_x = self.n_x
+        self._vol_geom_n_y = self.n_y
+        self._crop_data = True
+        if slice_roi is None:
+            return
+        start_x, end_x, start_y, end_y = slice_roi
+        if roi_is_centered(self.volume_shape[1:], (slice(start_y, end_y), slice(start_x, end_x))):
+            # Astra can only reconstruct subregion centered around the origin
+            self._vol_geom_n_x = self.n_x - start_x * 2
+            self._vol_geom_n_y = self.n_y - start_y * 2
+        else:
+            raise NotImplementedError(
+                "Astra supports only slice_roi centered around origin (got slice_roi=%s with n_x=%d, n_y=%d)"
+                % (str(slice_roi), self.n_x, self.n_y)
+            )
+
+    def _init_geometry(
+        self,
+        source_origin_dist,
+        origin_detector_dist,
+        pixel_size,
+        angles,
+        volume_shape,
+        rot_center,
+        relative_z_position,
+        slice_roi,
+    ):
+        if angles is None:
+            self.angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=True)
+        else:
+            self.angles = angles
+        if volume_shape is None:
+            volume_shape = (self.sinos_shape[0], self.sinos_shape[2], self.sinos_shape[2])
+        self.volume_shape = volume_shape
+        self.n_z, self.n_y, self.n_x = self.volume_shape
+        self.source_origin_dist = source_origin_dist
+        self.origin_detector_dist = origin_detector_dist
+        self.magnification = 1 + origin_detector_dist / source_origin_dist
+        self._set_slice_roi(slice_roi)
+        self.vol_geom = astra.create_vol_geom(self._vol_geom_n_y, self._vol_geom_n_x, self.n_z)
+        self.vol_shape = astra.geom_size(self.vol_geom)
+        self._cor_shift = 0.0
+        self.rot_center = rot_center
+        if rot_center is not None:
+            self._cor_shift = (self.sinos_shape[-1] - 1) / 2.0 - rot_center
+        self._set_pixel_size(pixel_size)
+        self._axis_corrections = self.extra_options.get("axis_correction", None)
+        self._create_astra_proj_geometry(relative_z_position)
+
+    def _create_astra_proj_geometry(self, relative_z_position):
+        # This object has to be re-created each time, because once the modifications below are done,
+        # it is no more a "cone" geometry but a "cone_vec" geometry, and cannot be updated subsequently
+        # (see astra/functions.py:271)
+        self.proj_geom = astra.create_proj_geom(
+            "cone",
+            self._det_spacing_x,
+            self._det_spacing_y,
+            self.n_sinos,
+            self.prj_width,
+            self.angles,
+            self.source_origin_dist,
+            self.origin_detector_dist,
+        )
+        self.relative_z_position = relative_z_position or 0.0
+        # This will turn the geometry of type "cone" into a geometry of type "cone_vec"
+        if self._orig_prog_geom is None:
+            self._orig_prog_geom = self.proj_geom
+        self.proj_geom = astra.geom_postalignment(self.proj_geom, (self._cor_shift, 0))
+        # (src, detector_center, u, v) = (srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ)
+        vecs = self.proj_geom["Vectors"]
+
+        # To adapt the center of rotation:
+        #  dX = cor_shift * cos(theta) - origin_detector_dist * sin(theta)
+        #  dY = origin_detector_dist * cos(theta) + cor_shift * sin(theta)
+        if self._axis_corrections is not None:
+            # should we check that dX and dY match the above formulas ?
+            cor_shifts = self._cor_shift + self._axis_corrections
+            vecs[:, 3] = cor_shifts * np.cos(self.angles) - self.origin_detector_dist * np.sin(self.angles)
+            vecs[:, 4] = self.origin_detector_dist * np.cos(self.angles) + cor_shifts * np.sin(self.angles)
+
+        # To adapt the z position:
+        # Component 2 of vecs is the z coordinate of the source, component 5 is the z component of the detector position
+        # We need to re-create the same inclination of the cone beam, thus we need to keep the inclination of the two z positions.
+        # The detector is centered on the rotation axis, thus moving it up or down, just moves it out of the reconstruction volume.
+        # We can bring back the detector in the correct volume position, by applying a rigid translation of both the detector and the source.
+        # The translation is exactly the amount that brought the detector up or down, but in the opposite direction.
+        vecs[:, 2] = -self.relative_z_position
+
+    def _set_output(self, volume):
+        if volume is not None:
+            expected_shape = self.vol_shape  # if not (self._crop_data) else self._output_cropped_shape
+            self.cuda.check_array(volume, expected_shape)
+            self.cuda.set_array("output", volume)
+        if volume is None:
+            self.cuda.allocate_array("output", self.vol_shape)
+        d_volume = self.cuda.get_array("output")
+        z, y, x = d_volume.shape
+        self._vol_link = astra.data3d.GPULink(d_volume.ptr, x, y, z, d_volume.strides[-2])
+        self._vol_id = astra.data3d.link("-vol", self.vol_geom, self._vol_link)
+
+    def _set_input(self, sinos):
+        self.cuda.check_array(sinos, self.sinos_shape)
+        self.cuda.set_array("sinos", sinos)  # self.cuda.sinos is now a GPU array
+        # TODO don't create new link/proj_id if ptr is the same ?
+        # But it seems Astra modifies the input sinogram while doing FDK, so this might be not relevant
+        d_sinos = self.cuda.get_array("sinos")
+
+        # self._proj_data_link = astra.data3d.GPULink(d_sinos.ptr, self.prj_width, self.n_angles, self.n_z, sinos.strides[-2])
+        self._proj_data_link = astra.data3d.GPULink(
+            d_sinos.ptr, self.prj_width, self.n_angles, self.n_sinos, d_sinos.strides[-2]
+        )
+        self._proj_id = astra.data3d.link("-sino", self.proj_geom, self._proj_data_link)
+
+    def _preprocess_data(self):
+        d_sinos = self.cuda.sinos
+        for i in range(d_sinos.shape[0]):
+            self.sino_filter.filter_sino(d_sinos[i], output=d_sinos[i])
+
+    def _update_reconstruction(self):
+        cfg = astra.astra_dict("BP3D_CUDA")
+        cfg["ReconstructionDataId"] = self._vol_id
+        cfg["ProjectionDataId"] = self._proj_id
+        if self._alg_id is not None:
+            astra.algorithm.delete(self._alg_id)
+        self._alg_id = astra.algorithm.create(cfg)
+
+    def reconstruct(self, sinos, output=None, relative_z_position=None):
+        """
+        sinos: numpy.ndarray or pycuda.gpuarray
+            Sinograms, with shape (n_sinograms, n_angles, width)
+        output: pycuda.gpuarray, optional
+            Output array. If not provided, a new numpy array is returned
+        relative_z_position: int, optional
+            Position of the central slice of the slab, with respect to the full stack of slices.
+            By default it is set to zero, meaning that the current slab is assumed in the middle of the stack
+        """
+        self._create_astra_proj_geometry(relative_z_position)
+        self._set_input(sinos)
+        self._set_output(output)
+        self._preprocess_data()
+        self._update_reconstruction()
+        astra.algorithm.run(self._alg_id)
+        #
+        # NB: Could also be done with
+        # from astra.experimental import direct_BP3D
+        # projector_id = astra.create_projector("cuda3d", self.proj_geom, self.vol_geom, options=None)
+        # direct_BP3D(projector_id, self._vol_link, self._proj_data_link)
+        #
+        result = self.cuda.get_array("output")
+        if output is None:
+            result = result.get()
+        if self.extra_options.get("scale_factor", None) is not None:
+            result *= np.float32(self.extra_options["scale_factor"])  # in-place for pycuda
+        self.cuda.recover_arrays_references(["sinos", "output"])
+        return result
+
+    def __del__(self):
+        if getattr(self, "_alg_id", None) is not None:
+            astra.algorithm.delete(self._alg_id)
+        if getattr(self, "_vol_id", None) is not None:
+            astra.data3d.delete(self._vol_id)
+        if getattr(self, "_proj_id", None) is not None:
+            astra.data3d.delete(self._proj_id)
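The center-of-rotation handling in _create_astra_proj_geometry reduces to shifting the detector by cor_shift = (width - 1)/2 - rot_center via astra.geom_postalignment, which converts the "cone" geometry into a "cone_vec" one whose per-angle vectors can then be edited. A standalone sketch of just that step, assuming the Astra toolbox is installed (sizes and distances are purely illustrative):

import numpy as np
import astra

n_angles, n_z, width = 500, 16, 2048
angles = np.linspace(0, 2 * np.pi, n_angles, endpoint=True)
rot_center = 1000.0                           # detector column of the rotation axis
cor_shift = (width - 1) / 2.0 - rot_center    # same convention as AstraReconstructor._cor_shift

# "cone" geometry: (spacing_x, spacing_y, det_rows, det_cols, angles, source-origin dist, origin-detector dist)
proj_geom = astra.create_proj_geom("cone", 1.0, 1.0, n_z, width, angles, 1e5, 100.0)
# Post-alignment shifts the detector by (cor_shift, 0) and yields a "cone_vec" geometry,
# exposing one (src, det_center, u, v) 12-vector per angle, which the class above then
# modifies for per-angle axis corrections and for the relative z position.
proj_geom = astra.geom_postalignment(proj_geom, (cor_shift, 0))
vecs = proj_geom["Vectors"]
print(vecs.shape)  # (n_angles, 12)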
nabu/reconstruction/fbp.py CHANGED
@@ -86,6 +86,13 @@ class CudaBackprojector(BackprojectorBase):
             self.sino_mult = CudaSinoMult(self.sino_shape, self.rot_center, ctx=self._processing.ctx)
         self._prepare_textures()  # has to be done after compilation for Cuda (to bind texture to built kernel)
 
+    def _get_filter_init_extra_options(self):
+        return {
+            "cuda_options": {
+                "ctx": self._processing.ctx,
+            },
+        }
+
     def _transfer_to_texture(self, sino, do_checks=True):
         if do_checks and not (sino.flags.c_contiguous):
             raise ValueError("Expected C-Contiguous array")
nabu/reconstruction/fbp_base.py CHANGED
@@ -7,6 +7,19 @@ from .sinogram import SinoMult
 from .sinogram import get_extended_sinogram_width
 
 
+def rot_center_is_in_middle_of_roi(rot_center, roi, tol=2.0):
+    # NB. tolerance should be at least 2,
+    # because in halftomo the extended sinogram width is 2*sino_width - int(2 * XXXX)
+    # (where XXX depends on whether the CoR is on the left or on the right)
+    # because of the int(2 * stuff), we can have a jump of at most two pixels.
+    #
+    start_x, end_x, start_y, end_y = roi
+    return (
+        abs((start_x + end_x - 1) / 2 - rot_center) - 0.5 < tol
+        and abs((start_y + end_y - 1) / 2 - rot_center) - 0.5 < tol
+    )
+
+
 class BackprojectorBase:
     """
     Base class for backprojectors.
@@ -162,9 +175,6 @@ class BackprojectorBase:
         self.axis_pos = self.rot_center
         self._set_angles(angles, n_angles)
         self._set_slice_roi(slice_roi)
-        #
-        # offset = start - move
-        # move = 0 if not(centered_axis) else start + (n-1)/2. - c
         if self.extra_options["centered_axis"]:
             self.offsets = {
                 "x": self.rot_center - (self.n_x - 1) / 2.0,
@@ -210,6 +220,19 @@
         end_x = convert_index(end_x, self.n_x, self.n_x)
         end_y = convert_index(end_y, self.n_y, self.n_y)
         self.slice_shape = (end_y - start_y, end_x - start_x)
+        if self.extra_options["centered_axis"] and not (
+            rot_center_is_in_middle_of_roi(self.rot_center, (start_x, end_x, start_y, end_y))
+        ):
+            warnings.warn(
+                "Using 'centered_axis' when doing a non-centered ROI reconstruction might have side effects: 'start_xy' and 'end_xy' have a different meaning",
+                RuntimeWarning,
+            )
+            # self.extra_options["centered_axis"] = False
+        if self.extra_options.get("clip_outer_circle", False) and (
+            start_x > 2 or start_y > 2 or abs(end_y - self.n_y) > 2 or abs(end_y - self.n_y) > 2
+        ):
+            warnings.warn("clip_outer_circle is not supported when doing RoI reconstruction", RuntimeWarning)
+            self.extra_options["clip_outer_circle"] = False
         self.n_x = self.slice_shape[-1]
         self.n_y = self.slice_shape[-2]
         self.offsets = {"x": start_x, "y": start_y}
@@ -239,19 +262,20 @@
             self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32)
             self._axis_correction[0, :] = axcorr[:]  # pylint: disable=E1136
 
+    def _get_filter_init_extra_options(self):
+        return {}
+
     def _init_filter(self, filter_name):
         self.filter_name = filter_name
         if filter_name in ["None", "none"]:
             self.sino_filter = None
             return
-        sinofilter_other_kwargs = {}
-        if self.backend != "numpy":
-            sinofilter_other_kwargs["%s_options" % self.backend] = {"ctx": self._processing.ctx}
-        sinofilter_other_kwargs["crop_filtered_data"] = self.extra_options.get("crop_filtered_data", True)
+
        # TODO
        if not (self.extra_options.get("crop_filtered_data", True)):
            warnings.warn("crop_filtered_data = False is not supported for FBP yet", RuntimeWarning)
        #
+        sinofilter_other_kwargs = self._get_filter_init_extra_options()
         self.sino_filter = self.SinoFilterClass(
             self.sino_shape,
             filter_name=self.filter_name,
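The rot_center_is_in_middle_of_roi helper added in the first hunk simply compares the ROI midpoints to the rotation-axis position with a 2-pixel tolerance; the 'centered_axis' warning above fires when that check fails. A quick worked example (values are illustrative):

def rot_center_is_in_middle_of_roi(rot_center, roi, tol=2.0):
    # Same logic as the helper added in fbp_base.py above
    start_x, end_x, start_y, end_y = roi
    return (
        abs((start_x + end_x - 1) / 2 - rot_center) - 0.5 < tol
        and abs((start_y + end_y - 1) / 2 - rot_center) - 0.5 < tol
    )

# Full 2048x2048 slice with the axis at the geometric center: no warning
print(rot_center_is_in_middle_of_roi(1023.5, (0, 2048, 0, 2048)))    # True
# Off-center ROI (start_x=300): its midpoint is 1173.5, far from the axis, so the warning fires
print(rot_center_is_in_middle_of_roi(1023.5, (300, 2048, 0, 2048)))  # False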
nabu/reconstruction/fbp_opencl.py CHANGED
@@ -74,5 +74,13 @@ class OpenCLBackprojector(BackprojectorBase):
             return
         return cl.enqueue_copy(self._processing.queue, self._d_sino.data, sino.data)
 
+    def _get_filter_init_extra_options(self):
+        return {
+            "opencl_options": {
+                "ctx": self._processing.ctx,
+                "queue": self._processing.queue,  # !!!!
+            },
+        }
+
     def _set_kernel_slice_arg(self, d_slice):
         self.kern_proj_args[1] = d_slice
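Both GPU backends now hand backend-specific keyword arguments to the sinogram filter through the _get_filter_init_extra_options hook defined in BackprojectorBase, instead of building a "%s_options" % backend key by hand. A toy, self-contained sketch of the pattern; the class and attribute names below are illustrative, not nabu's:

class FakeSinoFilter:
    def __init__(self, sino_shape, filter_name=None, **backend_kwargs):
        self.sino_shape = sino_shape
        self.filter_name = filter_name
        self.backend_kwargs = backend_kwargs


class BaseBackprojectorSketch:
    SinoFilterClass = FakeSinoFilter

    def _get_filter_init_extra_options(self):
        return {}  # numpy backend: the filter needs no extra options

    def _init_filter(self, sino_shape, filter_name):
        # Mirrors the refactored BackprojectorBase._init_filter: backend kwargs
        # come from an overridable hook rather than string-formatted dict keys.
        return self.SinoFilterClass(sino_shape, filter_name=filter_name, **self._get_filter_init_extra_options())


class OpenCLBackprojectorSketch(BaseBackprojectorSketch):
    def __init__(self, ctx, queue):
        self.ctx, self.queue = ctx, queue

    def _get_filter_init_extra_options(self):
        return {"opencl_options": {"ctx": self.ctx, "queue": self.queue}}


f = OpenCLBackprojectorSketch(ctx="ctx", queue="queue")._init_filter((500, 2048), "ramlak")
assert f.backend_kwargs == {"opencl_options": {"ctx": "ctx", "queue": "queue"}}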
nabu/reconstruction/filtering_opencl.py CHANGED
@@ -33,6 +33,8 @@ class OpenCLSinoFilter(SinoFilter):
             crop_filtered_data=crop_filtered_data,
             extra_options=extra_options,
         )
+        if not (crop_filtered_data):
+            raise NotImplementedError  # TODO
         self._init_kernels()
 
     def _init_fft(self):