nabu 2025.1.0.dev13__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (63):
  1. nabu/__init__.py +1 -1
  2. nabu/app/cast_volume.py +12 -1
  3. nabu/app/cli_configs.py +81 -4
  4. nabu/app/estimate_motion.py +54 -0
  5. nabu/app/multicor.py +2 -4
  6. nabu/app/pcaflats.py +116 -0
  7. nabu/app/reconstruct.py +1 -7
  8. nabu/app/reduce_dark_flat.py +5 -2
  9. nabu/estimation/cor.py +1 -1
  10. nabu/estimation/motion.py +557 -0
  11. nabu/estimation/tests/test_motion_estimation.py +471 -0
  12. nabu/estimation/tilt.py +1 -1
  13. nabu/estimation/translation.py +47 -1
  14. nabu/io/cast_volume.py +94 -13
  15. nabu/io/reader.py +32 -1
  16. nabu/io/tests/test_remove_volume.py +152 -0
  17. nabu/pipeline/config_validators.py +42 -43
  18. nabu/pipeline/estimators.py +255 -0
  19. nabu/pipeline/fullfield/chunked.py +67 -43
  20. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  21. nabu/pipeline/fullfield/nabu_config.py +17 -11
  22. nabu/pipeline/fullfield/processconfig.py +8 -2
  23. nabu/pipeline/fullfield/reconstruction.py +3 -0
  24. nabu/pipeline/params.py +12 -0
  25. nabu/pipeline/tests/test_estimators.py +240 -3
  26. nabu/preproc/ccd.py +53 -3
  27. nabu/preproc/flatfield.py +306 -1
  28. nabu/preproc/shift.py +3 -1
  29. nabu/preproc/tests/test_pcaflats.py +154 -0
  30. nabu/processing/rotation_cuda.py +3 -1
  31. nabu/processing/tests/test_rotation.py +4 -2
  32. nabu/reconstruction/fbp.py +7 -0
  33. nabu/reconstruction/fbp_base.py +31 -7
  34. nabu/reconstruction/fbp_opencl.py +8 -0
  35. nabu/reconstruction/filtering_opencl.py +2 -0
  36. nabu/reconstruction/mlem.py +51 -14
  37. nabu/reconstruction/tests/test_filtering.py +13 -2
  38. nabu/reconstruction/tests/test_mlem.py +91 -62
  39. nabu/resources/dataset_analyzer.py +144 -20
  40. nabu/resources/nxflatfield.py +101 -35
  41. nabu/resources/tests/test_nxflatfield.py +1 -1
  42. nabu/resources/utils.py +16 -10
  43. nabu/stitching/alignment.py +7 -7
  44. nabu/stitching/config.py +22 -20
  45. nabu/stitching/definitions.py +2 -2
  46. nabu/stitching/overlap.py +4 -4
  47. nabu/stitching/sample_normalization.py +5 -5
  48. nabu/stitching/stitcher/post_processing.py +5 -3
  49. nabu/stitching/stitcher/pre_processing.py +24 -20
  50. nabu/stitching/tests/test_config.py +3 -3
  51. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  52. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  53. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  54. nabu/stitching/utils/utils.py +7 -7
  55. nabu/testutils.py +1 -4
  56. nabu/utils.py +13 -0
  57. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
  58. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
  59. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
  60. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
  61. nabu/app/correct_rot.py +0 -62
  62. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  63. {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,19 @@ from .sinogram import SinoMult
7
7
  from .sinogram import get_extended_sinogram_width
8
8
 
9
9
 
10
+ def rot_center_is_in_middle_of_roi(rot_center, roi, tol=2.0):
11
+ # NB. tolerance should be at least 2,
12
+ # because in halftomo the extended sinogram width is 2*sino_width - int(2 * XXXX)
13
+ # (where XXX depends on whether the CoR is on the left or on the right)
14
+ # because of the int(2 * stuff), we can have a jump of at most two pixels.
15
+ #
16
+ start_x, end_x, start_y, end_y = roi
17
+ return (
18
+ abs((start_x + end_x - 1) / 2 - rot_center) - 0.5 < tol
19
+ and abs((start_y + end_y - 1) / 2 - rot_center) - 0.5 < tol
20
+ )
21
+
22
+
10
23
  class BackprojectorBase:
11
24
  """
12
25
  Base class for backprojectors.
@@ -162,9 +175,6 @@ class BackprojectorBase:
162
175
  self.axis_pos = self.rot_center
163
176
  self._set_angles(angles, n_angles)
164
177
  self._set_slice_roi(slice_roi)
165
- #
166
- # offset = start - move
167
- # move = 0 if not(centered_axis) else start + (n-1)/2. - c
168
178
  if self.extra_options["centered_axis"]:
169
179
  self.offsets = {
170
180
  "x": self.rot_center - (self.n_x - 1) / 2.0,
@@ -210,6 +220,19 @@ class BackprojectorBase:
210
220
  end_x = convert_index(end_x, self.n_x, self.n_x)
211
221
  end_y = convert_index(end_y, self.n_y, self.n_y)
212
222
  self.slice_shape = (end_y - start_y, end_x - start_x)
223
+ if self.extra_options["centered_axis"] and not (
224
+ rot_center_is_in_middle_of_roi(self.rot_center, (start_x, end_x, start_y, end_y))
225
+ ):
226
+ warnings.warn(
227
+ "Using 'centered_axis' when doing a non-centered ROI reconstruction might have side effects: 'start_xy' and 'end_xy' have a different meaning",
228
+ RuntimeWarning,
229
+ )
230
+ # self.extra_options["centered_axis"] = False
231
+ if self.extra_options.get("clip_outer_circle", False) and (
232
+ start_x > 2 or start_y > 2 or abs(end_y - self.n_y) > 2 or abs(end_y - self.n_y) > 2
233
+ ):
234
+ warnings.warn("clip_outer_circle is not supported when doing RoI reconstruction", RuntimeWarning)
235
+ self.extra_options["clip_outer_circle"] = False
213
236
  self.n_x = self.slice_shape[-1]
214
237
  self.n_y = self.slice_shape[-2]
215
238
  self.offsets = {"x": start_x, "y": start_y}
@@ -239,19 +262,20 @@ class BackprojectorBase:
239
262
  self._axis_correction = np.zeros((1, self.n_angles), dtype=np.float32)
240
263
  self._axis_correction[0, :] = axcorr[:] # pylint: disable=E1136
241
264
 
265
+ def _get_filter_init_extra_options(self):
266
+ return {}
267
+
242
268
  def _init_filter(self, filter_name):
243
269
  self.filter_name = filter_name
244
270
  if filter_name in ["None", "none"]:
245
271
  self.sino_filter = None
246
272
  return
247
- sinofilter_other_kwargs = {}
248
- if self.backend != "numpy":
249
- sinofilter_other_kwargs["%s_options" % self.backend] = {"ctx": self._processing.ctx}
250
- sinofilter_other_kwargs["crop_filtered_data"] = self.extra_options.get("crop_filtered_data", True)
273
+
251
274
  # TODO
252
275
  if not (self.extra_options.get("crop_filtered_data", True)):
253
276
  warnings.warn("crop_filtered_data = False is not supported for FBP yet", RuntimeWarning)
254
277
  #
278
+ sinofilter_other_kwargs = self._get_filter_init_extra_options()
255
279
  self.sino_filter = self.SinoFilterClass(
256
280
  self.sino_shape,
257
281
  filter_name=self.filter_name,
@@ -74,5 +74,13 @@ class OpenCLBackprojector(BackprojectorBase):
74
74
  return
75
75
  return cl.enqueue_copy(self._processing.queue, self._d_sino.data, sino.data)
76
76
 
77
+ def _get_filter_init_extra_options(self):
78
+ return {
79
+ "opencl_options": {
80
+ "ctx": self._processing.ctx,
81
+ "queue": self._processing.queue, # !!!!
82
+ },
83
+ }
84
+
77
85
  def _set_kernel_slice_arg(self, d_slice):
78
86
  self.kern_proj_args[1] = d_slice
@@ -33,6 +33,8 @@ class OpenCLSinoFilter(SinoFilter):
33
33
  crop_filtered_data=crop_filtered_data,
34
34
  extra_options=extra_options,
35
35
  )
36
+ if not (crop_filtered_data):
37
+ raise NotImplementedError # TODO
36
38
  self._init_kernels()
37
39
 
38
40
  def _init_fft(self):
@@ -12,6 +12,28 @@ except ImportError:
12
12
  class MLEMReconstructor:
13
13
  """
14
14
  A reconstructor for MLEM reconstruction using the CorrCT toolbox.
15
+
16
+ Parameters
17
+ ----------
18
+ data_vwu_shape : tuple
19
+ Shape of the input data, expected to be (n_slices, n_angles, n_dets). Raises an error if the shape is not 3D.
20
+ angles_rad : numpy.ndarray
21
+ Angles in radians for the projections. Must match the second dimension of `data_vwu_shape`.
22
+ shifts_vu : numpy.ndarray, optional.
23
+ Shifts in the v and u directions for each angle. If provided, must have the same number of cols as `angles_rad`. Each col is (tv,tu)
24
+ cor : float, optional
25
+ Center of rotation, which will be adjusted based on the sinogram width.
26
+ n_iterations : int, optional
27
+ Number of iterations for the MLEM algorithm. Default is 50.
28
+ extra_options : dict, optional
29
+ Additional options for the reconstruction process. Default options include:
30
+ - scale_factor (float, default is 1.0): Scale factor for the reconstruction.
31
+ - compute_shifts (boolean, default is False): Whether to compute shifts.
32
+ - tomo_consistency (boolean, default is False): Whether to enforce tomographic consistency.
33
+ - v_min_for_v_shifts (number, default is 0): Minimum value for vertical shifts.
34
+ - v_max_for_v_shifts (number, default is None): Maximum value for vertical shifts.
35
+ - v_min_for_u_shifts (number, default is 0): Minimum value for horizontal shifts.
36
+ - v_max_for_u_shifts (number, default is None): Maximum value for horizontal shifts.
15
37
  """
16
38
 
17
39
  default_extra_options = {
@@ -21,14 +43,21 @@ class MLEMReconstructor:
21
43
  "v_max_for_v_shifts": None,
22
44
  "v_min_for_u_shifts": 0,
23
45
  "v_max_for_u_shifts": None,
46
+ "scale_factor": 1.0,
47
+ "centered_axis": False,
48
+ "clip_outer_circle": False,
49
+ "outer_circle_value": 0.0,
50
+ "filter_cutoff": 1.0,
51
+ "padding_mode": None,
52
+ "crop_filtered_data": True,
24
53
  }
25
54
 
26
55
  def __init__(
27
56
  self,
28
- sinos_shape,
57
+ data_vwu_shape,
29
58
  angles_rad,
30
59
  shifts_uv=None,
31
- cor=None,
60
+ cor=None, # absolute
32
61
  n_iterations=50,
33
62
  extra_options=None,
34
63
  ):
@@ -36,9 +65,10 @@ class MLEMReconstructor:
36
65
  raise ImportError("Need corrct package")
37
66
  self.angles_rad = angles_rad
38
67
  self.n_iterations = n_iterations
68
+ self.scale_factor = extra_options.get("scale_factor", 1.0)
39
69
 
40
70
  self._configure_extra_options(extra_options)
41
- self._set_sino_shape(sinos_shape)
71
+ self._set_sino_shape(data_vwu_shape)
42
72
  self._set_shifts(shifts_uv, cor)
43
73
 
44
74
  def _configure_extra_options(self, extra_options):
@@ -57,17 +87,22 @@ class MLEMReconstructor:
57
87
 
58
88
  def _set_shifts(self, shifts_uv, cor):
59
89
  if shifts_uv is None:
60
- self.shifts_uv = np.zeros([self.n_angles, 2])
90
+ self.shifts_vu = None
61
91
  else:
62
92
  if shifts_uv.shape[0] != self.n_angles:
63
93
  raise ValueError(
64
94
  f"Number of shifts given ({shifts_uv.shape[0]}) does not mathc the number of projections ({self.n_angles})."
65
95
  )
66
- self.shifts_uv = shifts_uv.copy()
67
- self.cor = cor
96
+ self.shifts_vu = -shifts_uv.copy().T[::-1]
97
+ if cor is None:
98
+ self.cor = 0.0
99
+ else:
100
+ self.cor = (
101
+ -cor + (self.sinos_shape[-1] - 1) / 2.0
102
+ ) # convert absolute to relative in the ASTRA convention, which is opposite to Nabu relative convention.
68
103
 
69
104
  def reset_rot_center(self, cor):
70
- self.cor = cor
105
+ self.cor = -cor + (self.sinos_shape[-1] - 1) / 2.0
71
106
 
72
107
  def reconstruct(self, data_vwu):
73
108
  """
@@ -78,14 +113,17 @@ class MLEMReconstructor:
78
113
  """
79
114
  if not isinstance(data_vwu, np.ndarray):
80
115
  data_vwu = data_vwu.get()
81
- data_vwu /= data_vwu.mean()
116
+ # data_vwu /= data_vwu.mean()
82
117
 
83
118
  # MLEM recons
84
119
  self.vol_geom_align = cct.models.VolumeGeometry.get_default_from_data(data_vwu)
85
- self.prj_geom_align = cct.models.ProjectionGeometry.get_default_parallel()
86
- # Vertical shifts were handled in pipeline. Set them to ZERO
87
- self.shifts_uv[:, 1] = 0.0
88
- self.prj_geom_align.set_detector_shifts_vu(self.shifts_uv.T[::-1])
120
+ if self.shifts_vu is not None:
121
+ self.prj_geom_align = cct.models.ProjectionGeometry.get_default_parallel()
122
+ # Vertical shifts were handled in pipeline. Set them to ZERO
123
+ self.shifts_vu[:, 0] = 0.0
124
+ self.prj_geom_align.set_detector_shifts_vu(self.shifts_vu, self.cor)
125
+ else:
126
+ self.prj_geom_align = None
89
127
 
90
128
  variances_align = cct.processing.compute_variance_poisson(data_vwu)
91
129
  self.weights_align = cct.processing.compute_variance_weight(variances_align, normalized=True) # , use_std=True
@@ -97,5 +135,4 @@ class MLEMReconstructor:
97
135
  self.vol_geom_align, self.angles_rad, rot_axis_shift_pix=self.cor, prj_geom=self.prj_geom_align
98
136
  ) as A:
99
137
  rec, _ = solver(A, data_vwu, iterations=self.n_iterations, **self.solver_opts)
100
-
101
- return rec
138
+ return rec * self.scale_factor
@@ -106,7 +106,11 @@ class TestSinoFilter:
106
106
  assert id(res) == id(output), "when providing output, return value must not change"
107
107
 
108
108
  ref = filter_sinogram(
109
- h_sino, sino_filter.dwidth_padded, filter_name=config["filter_name"], padding_mode=config["padding_mode"]
109
+ h_sino,
110
+ sino_filter.dwidth_padded,
111
+ filter_name=config["filter_name"],
112
+ padding_mode=config["padding_mode"],
113
+ crop_filtered_data=config["crop_filtered_data"],
110
114
  )
111
115
 
112
116
  assert np.allclose(res.get(), ref, atol=6e-5), "test_cuda_filter: something wrong with config=%s" % (
@@ -118,6 +122,8 @@ class TestSinoFilter:
118
122
  )
119
123
  @pytest.mark.parametrize("config", tests_scenarios)
120
124
  def test_opencl_filter(self, config):
125
+ if not (config["crop_filtered_data"]):
126
+ pytest.skip("crop_filtered_data=False is not supported for OpenCL backend yet")
121
127
  sino = self.sino_cl if not (config["truncated_sino"]) else self.sino_truncated_cl
122
128
  h_sino = self.sino if not (config["truncated_sino"]) else self.sino_truncated
123
129
 
@@ -126,6 +132,7 @@ class TestSinoFilter:
126
132
  filter_name=config["filter_name"],
127
133
  padding_mode=config["padding_mode"],
128
134
  opencl_options={"ctx": self.cl.ctx},
135
+ crop_filtered_data=config["crop_filtered_data"],
129
136
  )
130
137
  if config["output_provided"]:
131
138
  output = parray.zeros(self.cl.queue, sino.shape, "f")
@@ -136,7 +143,11 @@ class TestSinoFilter:
136
143
  assert id(res) == id(output), "when providing output, return value must not change"
137
144
 
138
145
  ref = filter_sinogram(
139
- h_sino, sino_filter.dwidth_padded, filter_name=config["filter_name"], padding_mode=config["padding_mode"]
146
+ h_sino,
147
+ sino_filter.dwidth_padded,
148
+ filter_name=config["filter_name"],
149
+ padding_mode=config["padding_mode"],
150
+ crop_filtered_data=config["crop_filtered_data"],
140
151
  )
141
152
 
142
153
  assert np.allclose(res.get(), ref, atol=6e-5), "test_opencl_filter: something wrong with config=%s" % (
@@ -9,84 +9,113 @@ from nabu.reconstruction.mlem import MLEMReconstructor, __have_corrct__
9
9
  @pytest.fixture(scope="class")
10
10
  def bootstrap(request):
11
11
  cls = request.cls
12
- datafile = get_data("sl_mlem.npz")
13
- cls.data = datafile["data"]
12
+ datafile = get_data("test_mlem.npz")
13
+ cls.data_wvu = datafile["data_wvu"]
14
14
  cls.angles_rad = datafile["angles_rad"]
15
- cls.random_u_shifts = datafile["random_u_shifts"]
16
- cls.ref_rec_noshifts = datafile["ref_rec_noshifts"]
17
- cls.ref_rec_shiftsu = datafile["ref_rec_shiftsu"]
18
- cls.ref_rec_u_rand = datafile["ref_rec_u_rand"]
19
- cls.ref_rec_shiftsv = datafile["ref_rec_shiftsv"]
20
- # cls.ref_rec_v_rand = datafile["ref_rec_v_rand"]
21
- cls.tol = 2e-4
15
+ cls.pixel_size_cm = datafile["pixel_size"] * 1e4 # pixel_size originally in um
16
+ cls.true_cor = datafile["true_cor"]
17
+ cls.mlem_cor_None_nosh = datafile["mlem_cor_None_nosh"]
18
+ cls.mlem_cor_truecor_nosh = datafile["mlem_cor_truecor_nosh"]
19
+ cls.mlem_cor_truecor_shifts_v0 = datafile["mlem_cor_truecor_shifts_v0"]
20
+ cls.shifts_uv_v0 = datafile["shifts_uv_v0"]
21
+ cls.shifts_uv = datafile["shifts_uv"]
22
+
23
+ cls.tol = 1.3e-4
22
24
 
23
25
 
24
26
  @pytest.mark.skipif(not (__has_pycuda__ and __have_corrct__), reason="Need pycuda and corrct for this test")
25
27
  @pytest.mark.usefixtures("bootstrap")
26
- class TestMLEM:
28
+ class TestMLEMReconstructor:
27
29
  """These tests test the general MLEM reconstruction algorithm
28
30
  and the behavior of the reconstruction with respect to horizontal shifts.
29
31
  Only horizontal shifts are tested here because vertical shifts are handled outside
30
32
  the reconstruction object, but in the embedding reconstruction pipeline. See FullFieldReconstructor
33
+ It is compared against a reference reconstruction generated with the `rec_mlem` function
34
+ defined in the `generate_test_data.py` script.
31
35
  """
32
36
 
33
- def _create_MLEM_reconstructor(self, shifts_uv=None):
34
- return MLEMReconstructor(
35
- self.data.shape, -self.angles_rad, shifts_uv, cor=0.0, n_iterations=10 # mind the sign
37
+ def _rec_mlem(self, cor, shifts_uv, data_wvu, angles_rad):
38
+ n_angles, n_z, n_x = data_wvu.shape
39
+
40
+ mlem = MLEMReconstructor(
41
+ (n_z, n_angles, n_x),
42
+ angles_rad,
43
+ shifts_uv=shifts_uv,
44
+ cor=cor,
45
+ n_iterations=50,
46
+ extra_options={"centered_axis": True, "clip_outer_circle": True, "scale_factor": 1 / self.pixel_size_cm},
36
47
  )
48
+ rec_mlem = mlem.reconstruct(data_wvu.swapaxes(0, 1))
49
+ return rec_mlem
37
50
 
38
- def test_simple_mlem_recons(self):
39
- R = self._create_MLEM_reconstructor()
40
- rec = R.reconstruct(self.data)
41
- delta = np.abs(rec[:, ::-1, :] - self.ref_rec_noshifts)
51
+ def test_simple_mlem_recons_cor_None_nosh(self):
52
+ slice_index = 25
53
+ rec = self._rec_mlem(None, None, self.data_wvu, self.angles_rad)[slice_index]
54
+ delta = np.abs(rec - self.mlem_cor_None_nosh)
42
55
  assert np.max(delta) < self.tol
43
56
 
44
- def test_mlem_recons_with_u_shifts(self):
45
- shifts = np.zeros((len(self.angles_rad), 2))
46
- shifts[:, 0] = -5
47
- R = self._create_MLEM_reconstructor(shifts)
48
- rec = R.reconstruct(self.data)
49
- delta = np.abs(rec[:, ::-1] - self.ref_rec_shiftsu)
50
- assert np.max(delta) < self.tol
57
+ def test_simple_mlem_recons_cor_truecor_nosh(self):
58
+ slice_index = 25
59
+ rec = self._rec_mlem(self.true_cor, None, self.data_wvu, self.angles_rad)[slice_index]
60
+ delta = np.abs(rec - self.mlem_cor_truecor_nosh)
61
+ assert np.max(delta) < 2.6e-4
51
62
 
52
- def test_mlem_recons_with_random_u_shifts(self):
53
- R = self._create_MLEM_reconstructor(self.random_u_shifts)
54
- rec = R.reconstruct(self.data)
55
- delta = np.abs(rec[:, ::-1] - self.ref_rec_u_rand)
56
- assert np.max(delta) < self.tol
63
+ def test_compare_with_fbp(self):
64
+ from nabu.reconstruction.fbp import Backprojector
65
+
66
+ def _rec_fbp(cor, shifts_uv, data_wvu, angles_rad):
67
+ n_angles, n_z, n_x = data_wvu.shape
57
68
 
58
- def test_mlem_recons_with_constant_v_shifts(self):
59
- from nabu.preproc.shift import VerticalShift
69
+ if shifts_uv is None:
70
+ fbp = Backprojector(
71
+ (n_angles, n_x),
72
+ angles=angles_rad,
73
+ rot_center=cor,
74
+ halftomo=False,
75
+ padding_mode="edges",
76
+ extra_options={
77
+ "centered_axis": True,
78
+ "clip_outer_circle": True,
79
+ "scale_factor": 1 / self.pixel_size_cm,
80
+ },
81
+ )
82
+ else:
83
+ fbp = Backprojector(
84
+ (n_angles, n_x),
85
+ angles=angles_rad,
86
+ rot_center=cor,
87
+ halftomo=False,
88
+ padding_mode="edges",
89
+ extra_options={
90
+ "centered_axis": True,
91
+ "clip_outer_circle": True,
92
+ "scale_factor": 1 / self.pixel_size_cm, # convert um to cm
93
+ "axis_correction": shifts_uv[:, 0],
94
+ },
95
+ )
60
96
 
97
+ rec_fbp = np.zeros((n_z, n_x, n_x), "f")
98
+ for i in range(n_z):
99
+ rec_fbp[i] = fbp.fbp(data_wvu[:, i])
100
+
101
+ return rec_fbp
102
+
103
+ fbp = _rec_fbp(self.true_cor, None, self.data_wvu, self.angles_rad)[25]
104
+ mlem = self._rec_mlem(self.true_cor, None, self.data_wvu, self.angles_rad)[25]
105
+ delta = np.abs(fbp - mlem)
106
+ assert (
107
+ np.max(delta) < 400
108
+ ) # These two should not be really equal. But the test should test that both algo FBP and MLEM behave similarly.
109
+
110
+ def test_mlem_zeroshifts_equal_noshifts(self):
61
111
  shifts = np.zeros((len(self.angles_rad), 2))
62
- shifts[:, 1] = -20
63
-
64
- nv, n_angles, nu = self.data.shape
65
- radios_movements = VerticalShift(
66
- (n_angles, nv, nu), -shifts[:, 1]
67
- ) # Minus sign here mimics what is done in the pipeline.
68
- tmp_in = np.swapaxes(self.data, 0, 1).copy()
69
- tmp_out = np.zeros_like(tmp_in)
70
- radios_movements.apply_vertical_shifts(tmp_in, list(range(n_angles)), output=tmp_out)
71
- data = np.swapaxes(tmp_out, 0, 1).copy()
72
-
73
- R = self._create_MLEM_reconstructor(shifts)
74
- rec = R.reconstruct(data)
75
-
76
- axslice = 120
77
- trslice = 84
78
- axslice1 = self.ref_rec_shiftsv[axslice]
79
- axslice2 = rec[axslice, ::-1]
80
- trslice1 = self.ref_rec_shiftsv[trslice]
81
- trslice2 = rec[trslice, ::-1]
82
- # delta = np.abs(rec[:, ::-1] - self.ref_rec_shiftsv)
83
- delta_ax = np.abs(axslice1 - axslice2)
84
- delta_tr = np.abs(trslice1 - trslice2)
85
- assert max(np.max(delta_ax), np.max(delta_tr)) < self.tol
86
-
87
- @pytest.mark.skip(reason="No valid reference reconstruction for this test.")
88
- def test_mlem_recons_with_random_v_shifts(self):
89
- """NOT YET IMPLEMENTED.
90
- This is a temporary version due to unpexcted behavior of CorrCT/Astra to
91
- compute a reference implementation. See [question on Astra's github](https://github.com/astra-toolbox/astra-toolbox/discussions/520).
92
- """
112
+ rec_nosh = self._rec_mlem(self.true_cor, None, self.data_wvu, self.angles_rad)
113
+ rec_zerosh = self._rec_mlem(self.true_cor, shifts, self.data_wvu, self.angles_rad)
114
+ delta = np.abs(rec_nosh - rec_zerosh)
115
+ assert np.max(delta) < self.tol
116
+
117
+ def test_mlem_recons_with_u_shifts(self):
118
+ slice_index = 25
119
+ rec = self._rec_mlem(self.true_cor, self.shifts_uv_v0, self.data_wvu, self.angles_rad)[slice_index]
120
+ delta = np.abs(rec - self.mlem_cor_truecor_shifts_v0)
121
+ assert np.max(delta) < self.tol
@@ -1,3 +1,4 @@
1
+ from enum import Enum
1
2
  import os
2
3
  import numpy as np
3
4
  from silx.io.url import DataUrl
@@ -7,7 +8,7 @@ from tomoscan.esrf.scan.edfscan import EDFTomoScan
7
8
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
8
9
  from packaging.version import parse as parse_version
9
10
 
10
- from ..utils import BaseClassError, check_supported, indices_to_slices
11
+ from ..utils import BaseClassError, check_supported, indices_to_slices, is_scalar, search_sorted
11
12
  from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader
12
13
  from ..io.utils import get_compacted_dataslices
13
14
  from .utils import get_values_from_file, is_hdf5_extension
@@ -16,6 +17,33 @@ from .logger import LoggerOrPrint
16
17
  from ..pipeline.utils import nabu_env_settings
17
18
 
18
19
 
20
+ # We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
21
+ class ImageKey(Enum):
22
+ ALIGNMENT = -1
23
+ PROJECTION = 0
24
+ FLAT_FIELD = 1
25
+ DARK_FIELD = 2
26
+ INVALID = 3
27
+
28
+
29
+ # ---
30
+
31
+ _image_type = {
32
+ "projections": ImageKey.PROJECTION.value,
33
+ "projection": ImageKey.PROJECTION.value,
34
+ "radios": ImageKey.PROJECTION.value,
35
+ "radio": ImageKey.PROJECTION.value,
36
+ "flats": ImageKey.FLAT_FIELD.value,
37
+ "flat": ImageKey.FLAT_FIELD.value,
38
+ "darks": ImageKey.DARK_FIELD.value,
39
+ "dark": ImageKey.DARK_FIELD.value,
40
+ "static": ImageKey.ALIGNMENT.value,
41
+ "alignment": ImageKey.ALIGNMENT.value,
42
+ "return": ImageKey.ALIGNMENT.value,
43
+ "invalid": ImageKey.INVALID.value,
44
+ }
45
+
46
+
19
47
  class DatasetAnalyzer:
20
48
  _scanner = None
21
49
  kind = "none"
@@ -285,6 +313,24 @@ class DatasetAnalyzer:
285
313
  def scan_dirname(self):
286
314
  raise BaseClassError
287
315
 
316
+ def get_alignment_projections(self, image_sub_region=None):
317
+ raise NotImplementedError
318
+
319
+ @property
320
+ def all_angles(self):
321
+ raise NotImplementedError
322
+
323
+ def get_frame(self, idx): ...
324
+
325
+ @property
326
+ def is_360(self):
327
+ """
328
+ Return True iff the scan is 360 degrees (regardless of half-tomo mode)
329
+ """
330
+ angles = self.rotation_angles
331
+ d_theta = angles[1] - angles[0]
332
+ return np.isclose(angles.max() - angles.min(), 2 * np.pi, atol=2 * d_theta)
333
+
288
334
 
289
335
  class EDFDatasetAnalyzer(DatasetAnalyzer):
290
336
  """
@@ -358,9 +404,6 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
358
404
 
359
405
  _scanner = NXtomoScan
360
406
  kind = "nx"
361
- # We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
362
- _image_key_value = {"flats": 1, "darks": 2, "radios": 0}
363
- #
364
407
 
365
408
  @property
366
409
  def z_translation(self):
@@ -442,7 +485,8 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
442
485
  slices = set()
443
486
  for du in get_compacted_dataslices(images).values():
444
487
  if du.data_slice() is not None:
445
- s = (du.data_slice().start, du.data_slice().stop)
488
+ # note: du.data_slice is a uint in recent tomoscan version
489
+ s = (int(du.data_slice().start), int(du.data_slice().stop))
446
490
  else:
447
491
  s = None
448
492
  slices.add(s)
@@ -452,7 +496,7 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
452
496
  def _select_according_to_frame_type(self, data, frame_type):
453
497
  if data is None:
454
498
  return None
455
- return data[self.dataset_scanner.image_key_control == self._image_key_value[frame_type]]
499
+ return data[self.dataset_scanner.image_key_control == _image_type[frame_type]]
456
500
 
457
501
  def get_reduced_flats(self, method="median", force_reload=False, **reader_kwargs):
458
502
  dkrf_reader = NXDarksFlats(
@@ -475,9 +519,7 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
475
519
  For example, if the dataset flats are located at indices [1, 2, ..., 99], then
476
520
  frame_slices("flats") will return [slice(0, 100)].
477
521
  """
478
- return indices_to_slices(
479
- np.where(self.dataset_scanner.image_key_control == self._image_key_value[frame_type])[0]
480
- )
522
+ return indices_to_slices(np.where(self.dataset_scanner.image_key_control == _image_type[frame_type])[0])
481
523
 
482
524
  def get_reader(self, **kwargs):
483
525
  return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs)
@@ -492,18 +534,87 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
492
534
  # os.path.dirname(di.dataset_hdf5_url.file_path())
493
535
  return self.dataset_scanner.path
494
536
 
537
+ def get_alignment_projections(self, image_sub_region=None):
538
+ """
539
+ Get the extra projections (if any) that are used as "reference projections" for alignment.
540
+ For certain scan, when completing a (half) turn, sometimes extra projections are acquired for alignment purpose.
495
541
 
496
- def analyze_dataset(dataset_path, extra_options=None, logger=None):
497
- if not (os.path.isdir(dataset_path)):
498
- if not (os.path.isfile(dataset_path)):
499
- raise ValueError("Error: %s no such file or directory" % dataset_path)
500
- if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))):
501
- raise ValueError("Error: expected a HDF5 file")
502
- dataset_analyzer_class = HDF5DatasetAnalyzer
503
- else: # directory -> assuming EDF
504
- dataset_analyzer_class = EDFDatasetAnalyzer
505
- dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger)
506
- return dataset_structure
542
+ Returns
543
+ -------
544
+ projs: numpy.ndarray
545
+ Array with shape (n_projections, n_y, n_x)
546
+ angles: numpy.ndarray
547
+ Corresponding angles in degrees
548
+ indices:
549
+ Indices of projections
550
+ """
551
+ sub_region = None
552
+ if image_sub_region is not None:
553
+ sub_region = (None,) + image_sub_region
554
+ reader = self.get_reader(image_key=ImageKey.ALIGNMENT.value, sub_region=sub_region)
555
+ projs = reader.load_data()
556
+ indices = reader.get_frames_indices()
557
+ angles = get_angle_at_index(self.all_angles, indices)
558
+ return projs, angles, indices
559
+
560
+ @property
561
+ def all_angles(self):
562
+ return np.array(self.dataset_scanner.rotation_angle)
563
+
564
+ def get_index_from_angle(self, angle, image_key=0, return_found_angle=False):
565
+ """
566
+ Return the index of the image taken at rotation angle 'angle'.
567
+ By default look at the projections, i.e image_key = 0
568
+ """
569
+ all_angles = self.all_angles
570
+ all_indices = np.arange(len(all_angles))
571
+ all_image_key = self.dataset_scanner.image_key_control
572
+
573
+ idx2 = np.where(all_image_key == image_key)[0]
574
+ angles = all_angles[idx2]
575
+ idx_angles_sorted = np.argsort(angles)
576
+ angles_sorted = angles[idx_angles_sorted]
577
+
578
+ pos = search_sorted(angles_sorted, angle)
579
+ # this gives a position in "idx2", but we need the position in "all_indices"
580
+ idx = all_indices[idx2[idx_angles_sorted[pos]]]
581
+ if return_found_angle:
582
+ return idx, angles_sorted[pos]
583
+ return idx
584
+
585
+ def get_image_at_angle(self, angle_deg, image_type="projection", sub_region=None, return_angle_and_index=False):
586
+ image_key = _image_type[image_type]
587
+ idx, angle_found = self.get_index_from_angle(angle_deg, image_key=image_key, return_found_angle=True)
588
+
589
+ # Option 1:
590
+ if sub_region is None:
591
+ sub_region = (None, None)
592
+ # Convert absolute index to index of image_key
593
+ idx2 = np.searchsorted(np.where(self.dataset_scanner.image_key_control == image_key)[0], idx)
594
+ sub_region = (slice(idx2, idx2 + 1),) + sub_region
595
+ reader = self.get_reader(image_key=image_key, sub_region=sub_region)
596
+ img = reader.load_data()[0]
597
+ if return_angle_and_index:
598
+ return img, angle_found, idx
599
+ return img
600
+
601
+ # Option 2:
602
+ # return self.get_frame(idx)
603
+ # something like:
604
+ # [fr for fr in self.dataset_scanner.frames if fr.image_key.value == 0 and fr.rotation_angle == 180 and fr._is_control_frame is False]
605
+
606
+ def get_frame(self, idx):
607
+ return get_data(self.dataset_scanner.frames[idx].url)
608
+
609
+
610
+ def get_angle_at_index(all_angles, index):
611
+ """
612
+ Return the rotation angle corresponding to image index 'index'
613
+ """
614
+ if is_scalar(index):
615
+ return all_angles[index]
616
+ else:
617
+ return all_angles[np.array(index)]
507
618
 
508
619
 
509
620
  def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
@@ -549,3 +660,16 @@ def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
549
660
  return radios, radios_indices
550
661
  else:
551
662
  return radios
663
+
664
+
665
+ def analyze_dataset(dataset_path, extra_options=None, logger=None):
666
+ if not (os.path.isdir(dataset_path)):
667
+ if not (os.path.isfile(dataset_path)):
668
+ raise ValueError("Error: %s no such file or directory" % dataset_path)
669
+ if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))):
670
+ raise ValueError("Error: expected a HDF5 file")
671
+ dataset_analyzer_class = HDF5DatasetAnalyzer
672
+ else: # directory -> assuming EDF
673
+ dataset_analyzer_class = EDFDatasetAnalyzer
674
+ dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger)
675
+ return dataset_structure