nabu 2025.1.0.dev5__py3-none-any.whl → 2025.1.0.dev12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/double_flatfield.py +18 -5
  4. nabu/app/reconstruct_helical.py +4 -4
  5. nabu/app/stitching.py +7 -2
  6. nabu/cuda/src/backproj.cu +10 -10
  7. nabu/cuda/src/cone.cu +4 -0
  8. nabu/cuda/utils.py +1 -1
  9. nabu/estimation/cor.py +3 -3
  10. nabu/io/cast_volume.py +13 -0
  11. nabu/io/reader.py +3 -2
  12. nabu/opencl/src/backproj.cl +10 -10
  13. nabu/pipeline/estimators.py +6 -6
  14. nabu/pipeline/fullfield/chunked.py +13 -13
  15. nabu/pipeline/fullfield/computations.py +4 -1
  16. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  17. nabu/pipeline/fullfield/nabu_config.py +16 -4
  18. nabu/pipeline/fullfield/processconfig.py +22 -2
  19. nabu/pipeline/fullfield/reconstruction.py +9 -4
  20. nabu/pipeline/helical/gridded_accumulator.py +1 -1
  21. nabu/pipeline/helical/helical_reconstruction.py +2 -2
  22. nabu/pipeline/helical/nabu_config.py +1 -1
  23. nabu/pipeline/helical/weight_balancer.py +1 -1
  24. nabu/pipeline/params.py +8 -3
  25. nabu/preproc/shift.py +1 -1
  26. nabu/processing/fft_base.py +6 -2
  27. nabu/processing/fft_cuda.py +23 -4
  28. nabu/processing/fft_opencl.py +19 -2
  29. nabu/processing/padding_cuda.py +0 -1
  30. nabu/processing/processing_base.py +11 -5
  31. nabu/reconstruction/astra.py +245 -0
  32. nabu/reconstruction/cone.py +34 -9
  33. nabu/reconstruction/fbp.py +7 -0
  34. nabu/reconstruction/fbp_base.py +8 -0
  35. nabu/reconstruction/filtering.py +59 -25
  36. nabu/reconstruction/filtering_cuda.py +21 -20
  37. nabu/reconstruction/filtering_opencl.py +8 -14
  38. nabu/reconstruction/hbp.py +10 -10
  39. nabu/reconstruction/rings_cuda.py +41 -13
  40. nabu/reconstruction/tests/test_cone.py +35 -0
  41. nabu/reconstruction/tests/test_fbp.py +32 -11
  42. nabu/reconstruction/tests/test_filtering.py +14 -5
  43. nabu/resources/dataset_analyzer.py +34 -2
  44. nabu/resources/tests/test_extract.py +4 -2
  45. nabu/stitching/config.py +6 -1
  46. nabu/stitching/stitcher/dumper/__init__.py +1 -0
  47. nabu/stitching/stitcher/dumper/postprocessing.py +105 -1
  48. nabu/stitching/stitcher/post_processing.py +14 -4
  49. nabu/stitching/stitcher/pre_processing.py +1 -1
  50. nabu/stitching/stitcher/single_axis.py +8 -7
  51. nabu/stitching/stitcher/z_stitcher.py +8 -4
  52. nabu/stitching/utils/utils.py +2 -2
  53. nabu/testutils.py +2 -2
  54. nabu/utils.py +9 -2
  55. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/METADATA +9 -28
  56. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/RECORD +60 -57
  57. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/WHEEL +1 -1
  58. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/entry_points.txt +0 -0
  59. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info/licenses}/LICENSE +0 -0
  60. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import posixpath
3
3
  import numpy as np
4
+ from .get_double_flatfield import get_double_flatfield
4
5
  from silx.io import get_data
5
6
  from silx.io.url import DataUrl
6
7
  from ...utils import copy_dict_items, compare_dicts
@@ -32,6 +33,7 @@ class ProcessConfig(ProcessConfigBase):
32
33
 
33
34
  (2) update_dataset_info_with_user_config
34
35
  - Update flats/darks
36
+ - Double-flat-field
35
37
  - CoR (value or estimation method) # no estimation yet
36
38
  - rotation angles
37
39
  - translations files
@@ -89,12 +91,29 @@ class ProcessConfig(ProcessConfigBase):
89
91
  self.subsampling_factor = subsampling_factor or 1
90
92
  self.subsampling_start = subsampling_start or 0
91
93
 
94
+ self._get_double_flatfield()
92
95
  self._update_dataset_with_user_overwrites()
93
96
  self._get_rotation_axis_position()
94
97
  self._update_rotation_angles()
95
98
  self._get_translation_file("reconstruction", "translation_movements_file", "translations")
96
99
  self._get_user_sino_normalization()
97
100
 
101
+ def _get_double_flatfield(self):
102
+ self._dff_file = None
103
+ dff_mode = self.nabu_config["preproc"]["double_flatfield"]
104
+ if not (dff_mode):
105
+ return
106
+ self._dff_file = get_double_flatfield(
107
+ self.dataset_info,
108
+ dff_mode,
109
+ output_dir=self.nabu_config["output"]["location"],
110
+ darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"],
111
+ dff_options={
112
+ "dff_sigma": self.nabu_config["preproc"]["dff_sigma"],
113
+ "do_flatfield": (self.nabu_config["preproc"]["flatfield"] is not False),
114
+ },
115
+ )
116
+
98
117
  def _update_dataset_with_user_overwrites(self):
99
118
  user_overwrites = self.nabu_config["dataset"]["overwrite_metadata"].strip()
100
119
  if user_overwrites in ("", None):
@@ -451,11 +470,11 @@ class ProcessConfig(ProcessConfigBase):
451
470
  #
452
471
  # Double flat field
453
472
  #
454
- if nabu_config["preproc"]["double_flatfield_enabled"]:
473
+ if nabu_config["preproc"]["double_flatfield"]:
455
474
  tasks.append("double_flatfield")
456
475
  options["double_flatfield"] = {
457
476
  "sigma": nabu_config["preproc"]["dff_sigma"],
458
- "processes_file": nabu_config["preproc"]["processes_file"],
477
+ "processes_file": self._dff_file or nabu_config["preproc"]["processes_file"],
459
478
  "log_min_clip": nabu_config["preproc"]["log_min_clip"],
460
479
  "log_max_clip": nabu_config["preproc"]["log_max_clip"],
461
480
  }
@@ -575,6 +594,7 @@ class ProcessConfig(ProcessConfigBase):
575
594
  "sample_detector_dist",
576
595
  "hbp_legs",
577
596
  "hbp_reduction_steps",
597
+ "crop_filtered_data",
578
598
  ],
579
599
  )
580
600
  rec_options = options["reconstruction"]
@@ -120,7 +120,7 @@ class FullFieldReconstructor:
120
120
  vm = virtual_memory()
121
121
  self.resources["mem_avail_GB"] = vm.available / 1e9
122
122
  # Account for other memory constraints. There might be a better way
123
- slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0))
123
+ slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0)) # noqa: PLW1508
124
124
  if slurm_mem_constraint_MB > 0:
125
125
  self.resources["mem_avail_GB"] = slurm_mem_constraint_MB / 1e3
126
126
  #
@@ -131,8 +131,12 @@ class FullFieldReconstructor:
131
131
  self.resources["gpus"] = avail_gpus
132
132
  if len(avail_gpus) == 0:
133
133
  return
134
- # pick first GPU by default. TODO: handle user's nabu_config["resources"]["gpu_id"]
135
- self.resources["gpu_id"] = self._gpu_id = first_generator_item(avail_gpus.keys())
134
+ user_gpus = self.process_config.nabu_config.get("resources", {}).get("gpu_id", [])
135
+ if len(user_gpus) == 0:
136
+ user_gpus = [0]
137
+ # For now nabu does not support multi-GPU reconstruction. Take the first one.
138
+ user_gpu_idx = user_gpus[0]
139
+ self.resources["gpu_id"] = self._gpu_id = list(avail_gpus.keys())[user_gpu_idx]
136
140
 
137
141
  def _get_backend(self, backend, cuda_options):
138
142
  self._pipeline_cls = ChunkedPipeline
@@ -145,6 +149,7 @@ class FullFieldReconstructor:
145
149
  backend = "numpy"
146
150
  else:
147
151
  self.gpu_mem = self.resources["gpus"][self._gpu_id]["memory_GB"] * self.gpu_mem_fraction
152
+ self.cuda_options = {"device_id": self._gpu_id}
148
153
  if backend == "cuda":
149
154
  if not (__has_pycuda__):
150
155
  raise RuntimeError("pycuda not avilable")
@@ -307,7 +312,7 @@ class FullFieldReconstructor:
307
312
  sigma = opts["unsharp_sigma"]
308
313
  # nabu uses cutoff = 4
309
314
  cutoff = 4
310
- gaussian_kernel_size = int(ceil(2 * cutoff * sigma + 1))
315
+ gaussian_kernel_size = ceil(2 * cutoff * sigma + 1)
311
316
  self.logger.debug("Unsharp mask margin: %d pixels" % gaussian_kernel_size)
312
317
  return (gaussian_kernel_size, gaussian_kernel_size)
313
318
 
@@ -532,7 +532,7 @@ def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_ma
532
532
  # regridded dataset, estimating a meaningul angular step representative
533
533
  # of the raw data
534
534
  my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean())
535
- n_gridded_angles = int(round(360.0 / my_angle_step))
535
+ n_gridded_angles = round(360.0 / my_angle_step)
536
536
 
537
537
  radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix
538
538
 
@@ -168,8 +168,8 @@ class HelicalReconstructorRegridded:
168
168
 
169
169
  # the meaming of z_min and z_max is: position in slices units from the
170
170
  # first available slice and in the direction of the scan
171
- self.z_min = int(round(z_start * (0 - z_fract_min) + z_max * z_fract_min))
172
- self.z_max = int(round(z_start * (0 - z_fract_max) + z_max * z_fract_max)) + 1
171
+ self.z_min = round(z_start * (0 - z_fract_min) + z_max * z_fract_min)
172
+ self.z_max = round(z_start * (0 - z_fract_max) + z_max * z_fract_max) + 1
173
173
 
174
174
  def _compute_translations_margin(self):
175
175
  return 0, 0
@@ -43,7 +43,7 @@ nabu_config["preproc"]["processes_file"] = {
43
43
  "validator": optional_file_location_validator,
44
44
  "type": "required",
45
45
  }
46
- nabu_config["preproc"]["double_flatfield_enabled"]["default"] = 1
46
+ nabu_config["preproc"]["double_flatfield"]["default"] = 1
47
47
 
48
48
 
49
49
  nabu_config["reconstruction"].update(
@@ -83,7 +83,7 @@ def shift(arr, shift, fill_value=0.0):
83
83
  """
84
84
  result = np.zeros_like(arr)
85
85
 
86
- num1 = int(math.floor(shift))
86
+ num1 = math.floor(shift)
87
87
  num2 = num1 + 1
88
88
  partition = shift - num1
89
89
 
nabu/pipeline/params.py CHANGED
@@ -25,12 +25,17 @@ unsharp_methods = {
25
25
  "": None,
26
26
  }
27
27
 
28
+ # see PaddingBase.supported_modes
28
29
  padding_modes = {
29
- "edges": "edge",
30
- "edge": "edge",
31
- "mirror": "mirror",
32
30
  "zeros": "zeros",
33
31
  "zero": "zeros",
32
+ "constant": "zeros",
33
+ "edges": "edge",
34
+ "edge": "edge",
35
+ "mirror": "reflect",
36
+ "reflect": "reflect",
37
+ "symmetric": "symmetric",
38
+ "wrap": "wrap",
34
39
  }
35
40
 
36
41
  reconstruction_methods = {
nabu/preproc/shift.py CHANGED
@@ -42,7 +42,7 @@ class VerticalShift:
42
42
  def _init_interp_coefficients(self):
43
43
  self.interp_infos = []
44
44
  for s in self.shifts:
45
- s0 = int(floor(s))
45
+ s0 = floor(s)
46
46
  f = s - s0
47
47
  self.interp_infos.append([s0, f])
48
48
 
@@ -93,6 +93,10 @@ class _BaseFFT:
93
93
  pass
94
94
 
95
95
 
96
+ def raise_base_class_error(slf, *args, **kwargs):
97
+ raise ValueError
98
+
99
+
96
100
  class _BaseVKFFT(_BaseFFT):
97
101
  """
98
102
  FFT using VKFFT backend
@@ -101,7 +105,7 @@ class _BaseVKFFT(_BaseFFT):
101
105
  implem = "vkfft"
102
106
  backend = "none"
103
107
  ProcessingCls = BaseClassError
104
- vkffs_cls = BaseClassError
108
+ get_fft_obj = raise_base_class_error
105
109
 
106
110
  def _configure_batched_transform(self):
107
111
  if self.axes is not None and len(self.shape) == len(self.axes):
@@ -128,7 +132,7 @@ class _BaseVKFFT(_BaseFFT):
128
132
  self._vkfft_ndim = None
129
133
 
130
134
  def _compute_fft_plans(self):
131
- self._vkfft_plan = self.vkffs_cls(
135
+ self._vkfft_plan = self.get_fft_obj(
132
136
  self.shape,
133
137
  self.dtype,
134
138
  ndim=self._vkfft_ndim,
@@ -1,18 +1,19 @@
1
1
  import os
2
2
  import warnings
3
+ from functools import lru_cache
3
4
  from multiprocessing import get_context
4
5
  from multiprocessing.pool import Pool
5
6
  import numpy as np
6
- from ..utils import check_supported
7
+ from ..utils import BaseClassError, check_supported, no_decorator
7
8
  from .fft_base import _BaseFFT, _BaseVKFFT
8
9
 
9
10
  try:
10
- from pyvkfft.cuda import VkFFTApp as vk_cufft
11
+ from pyvkfft.cuda import VkFFTApp as CudaVkFFTApp
11
12
 
12
13
  __has_vkfft__ = True
13
14
  except (ImportError, OSError):
14
15
  __has_vkfft__ = False
15
- vk_cufft = None
16
+ CudaVkFFTApp = BaseClassError
16
17
  from ..cuda.processing import CudaProcessing
17
18
 
18
19
  Plan = None
@@ -20,6 +21,8 @@ cu_fft = None
20
21
  cu_ifft = None
21
22
  __has_skcuda__ = None
22
23
 
24
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
25
+
23
26
 
24
27
  def init_skcuda():
25
28
  # This needs to be done here, because scikit-cuda creates a Cuda context at import,
@@ -146,6 +149,18 @@ class SKCUFFT(_BaseFFT):
146
149
  return output
147
150
 
148
151
 
152
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
153
+
154
+
155
+ @maybe_cached
156
+ def _get_vkfft_cuda(*args, **kwargs):
157
+ return CudaVkFFTApp(*args, **kwargs)
158
+
159
+
160
+ def get_vkfft_cuda(slf, *args, **kwargs):
161
+ return _get_vkfft_cuda(*args, **kwargs)
162
+
163
+
149
164
  class VKCUFFT(_BaseVKFFT):
150
165
  """
151
166
  Cuda FFT, using VKFFT backend
@@ -154,7 +169,7 @@ class VKCUFFT(_BaseVKFFT):
154
169
  implem = "vkfft"
155
170
  backend = "cuda"
156
171
  ProcessingCls = CudaProcessing
157
- vkffs_cls = vk_cufft
172
+ get_fft_obj = get_vkfft_cuda
158
173
 
159
174
  def _init_backend(self, backend_options):
160
175
  super()._init_backend(backend_options)
@@ -175,6 +190,7 @@ def _has_vkfft(x):
175
190
  return avail
176
191
 
177
192
 
193
+ @lru_cache(maxsize=2)
178
194
  def has_vkfft(safe=True):
179
195
  """
180
196
  Determine whether pyvkfft is available.
@@ -208,6 +224,7 @@ def _has_skfft(x):
208
224
  return avail
209
225
 
210
226
 
227
+ @lru_cache(maxsize=2)
211
228
  def has_skcuda(safe=True):
212
229
  """
213
230
  Determine whether scikit-cuda/CUFFT is available.
@@ -227,6 +244,7 @@ def has_skcuda(safe=True):
227
244
  return v
228
245
 
229
246
 
247
+ @lru_cache(maxsize=2)
230
248
  def get_fft_class(backend="vkfft"):
231
249
  backends = {
232
250
  "scikit-cuda": SKCUFFT,
@@ -256,6 +274,7 @@ def get_fft_class(backend="vkfft"):
256
274
  return get_fft_cls(backend)
257
275
 
258
276
 
277
+ @lru_cache(maxsize=1)
259
278
  def get_available_fft_implems():
260
279
  avail_implems = []
261
280
  if has_vkfft(safe=True):
@@ -1,15 +1,32 @@
1
+ from functools import lru_cache
2
+ import os
1
3
  from multiprocessing import get_context
2
4
  from multiprocessing.pool import Pool
5
+
6
+ from ..utils import BaseClassError, no_decorator
3
7
  from .fft_base import _BaseVKFFT
4
8
  from ..opencl.processing import OpenCLProcessing
5
9
 
6
10
  try:
7
- from pyvkfft.opencl import VkFFTApp as vk_clfft
11
+ from pyvkfft.opencl import VkFFTApp as OpenCLVkFFTApp
8
12
 
9
13
  __has_vkfft__ = True
10
14
  except (ImportError, OSError):
11
15
  __has_vkfft__ = False
12
16
  vk_clfft = None
17
+ OpenCLVkFFTApp = BaseClassError
18
+
19
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
20
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
21
+
22
+
23
+ @maybe_cached
24
+ def _get_vkfft_opencl(*args, **kwargs):
25
+ return OpenCLVkFFTApp(*args, **kwargs)
26
+
27
+
28
+ def get_vkfft_opencl(slf, *args, **kwargs):
29
+ return _get_vkfft_opencl(*args, **kwargs)
13
30
 
14
31
 
15
32
  class VKCLFFT(_BaseVKFFT):
@@ -20,7 +37,7 @@ class VKCLFFT(_BaseVKFFT):
20
37
  implem = "vkfft"
21
38
  backend = "opencl"
22
39
  ProcessingCls = OpenCLProcessing
23
- vkffs_cls = vk_clfft
40
+ get_fft_obj = get_vkfft_opencl
24
41
 
25
42
  def _init_backend(self, backend_options):
26
43
  super()._init_backend(backend_options)
@@ -11,7 +11,6 @@ class CudaPadding(PaddingBase):
11
11
 
12
12
  backend = "cuda"
13
13
 
14
- # TODO docstring from base class
15
14
  def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs):
16
15
  super().__init__(shape, pad_width, mode=mode, **kwargs)
17
16
  self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {}))
@@ -99,6 +99,15 @@ class ProcessingBase:
99
99
  _recover_arrays_references = recover_arrays_references
100
100
  _allocate_array = allocate_array
101
101
  _set_array = set_array
102
+ # --
103
+
104
+ def is_contiguous(self, arr):
105
+ if isinstance(arr, self.array_class):
106
+ return arr.flags.c_contiguous
107
+ elif isinstance(arr, np.ndarray):
108
+ return arr.flags["C_CONTIGUOUS"]
109
+ else:
110
+ raise TypeError
102
111
 
103
112
  def check_array(self, arr, expected_shape, expected_dtype="f", check_contiguous=True):
104
113
  """
@@ -108,11 +117,8 @@ class ProcessingBase:
108
117
  raise ValueError("Expected shape %s but got %s" % (str(expected_shape), str(arr.shape)))
109
118
  if arr.dtype != np.dtype(expected_dtype):
110
119
  raise ValueError("Expected data type %s but got %s" % (str(expected_dtype), str(arr.dtype)))
111
- if check_contiguous:
112
- if isinstance(arr, np.ndarray) and not (arr.flags["C_CONTIGUOUS"]):
113
- raise ValueError("Expected C-contiguous array")
114
- if isinstance(arr, self.array_class) and not arr.flags.c_contiguous:
115
- raise ValueError("Expected C-contiguous array")
120
+ if check_contiguous and not (self.is_contiguous(arr)):
121
+ raise ValueError("Expected C-contiguous array")
116
122
 
117
123
  def kernel(self, *args, **kwargs):
118
124
  raise ValueError("Base class")
@@ -0,0 +1,245 @@
1
+ # ruff: noqa
2
+ try:
3
+ import astra
4
+
5
+ __have_astra__ = True
6
+ except ImportError:
7
+ __have_astra__ = False
8
+ astra = None
9
+
10
+
11
+ class AstraReconstructor:
12
+ """
13
+ Base class for reconstructors based on the Astra toolbox
14
+ """
15
+
16
+ default_extra_options = {
17
+ "axis_correction": None,
18
+ "clip_outer_circle": False,
19
+ "scale_factor": None,
20
+ "filter_cutoff": 1.0,
21
+ "outer_circle_value": 0.0,
22
+ }
23
+
24
+ def __init__(
25
+ self,
26
+ sinos_shape,
27
+ angles=None,
28
+ volume_shape=None,
29
+ rot_center=None,
30
+ pixel_size=None,
31
+ padding_mode="zeros",
32
+ filter_name=None,
33
+ slice_roi=None,
34
+ cuda_options=None,
35
+ extra_options=None,
36
+ ):
37
+ self._configure_extra_options(extra_options)
38
+ self._init_cuda(cuda_options)
39
+ self._set_sino_shape(sinos_shape)
40
+ self._orig_prog_geom = None
41
+ self._init_geometry(
42
+ source_origin_dist,
43
+ origin_detector_dist,
44
+ pixel_size,
45
+ angles,
46
+ volume_shape,
47
+ rot_center,
48
+ relative_z_position,
49
+ slice_roi,
50
+ )
51
+ self._init_fdk(padding_mode, filter_name)
52
+ self._alg_id = None
53
+ self._vol_id = None
54
+ self._proj_id = None
55
+
56
+ def _configure_extra_options(self, extra_options):
57
+ self.extra_options = self.default_extra_options.copy()
58
+ self.extra_options.update(extra_options or {})
59
+
60
+ def _init_cuda(self, cuda_options):
61
+ cuda_options = cuda_options or {}
62
+ self.cuda = CudaProcessing(**cuda_options)
63
+
64
+ def _set_sino_shape(self, sinos_shape):
65
+ if len(sinos_shape) != 3:
66
+ raise ValueError("Expected a 3D shape")
67
+ self.sinos_shape = sinos_shape
68
+ self.n_sinos, self.n_angles, self.prj_width = sinos_shape
69
+
70
+ def _set_pixel_size(self, pixel_size):
71
+ if pixel_size is None:
72
+ det_spacing_y = det_spacing_x = 1
73
+ elif np.iterable(pixel_size):
74
+ det_spacing_y, det_spacing_x = pixel_size
75
+ else:
76
+ # assuming scalar
77
+ det_spacing_y = det_spacing_x = pixel_size
78
+ self._det_spacing_y = det_spacing_y
79
+ self._det_spacing_x = det_spacing_x
80
+
81
+ def _set_slice_roi(self, slice_roi):
82
+ self.slice_roi = slice_roi
83
+ self._vol_geom_n_x = self.n_x
84
+ self._vol_geom_n_y = self.n_y
85
+ self._crop_data = True
86
+ if slice_roi is None:
87
+ return
88
+ start_x, end_x, start_y, end_y = slice_roi
89
+ if roi_is_centered(self.volume_shape[1:], (slice(start_y, end_y), slice(start_x, end_x))):
90
+ # Astra can only reconstruct subregion centered around the origin
91
+ self._vol_geom_n_x = self.n_x - start_x * 2
92
+ self._vol_geom_n_y = self.n_y - start_y * 2
93
+ else:
94
+ raise NotImplementedError(
95
+ "Astra supports only slice_roi centered around origin (got slice_roi=%s with n_x=%d, n_y=%d)"
96
+ % (str(slice_roi), self.n_x, self.n_y)
97
+ )
98
+
99
+ def _init_geometry(
100
+ self,
101
+ source_origin_dist,
102
+ origin_detector_dist,
103
+ pixel_size,
104
+ angles,
105
+ volume_shape,
106
+ rot_center,
107
+ relative_z_position,
108
+ slice_roi,
109
+ ):
110
+ if angles is None:
111
+ self.angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=True)
112
+ else:
113
+ self.angles = angles
114
+ if volume_shape is None:
115
+ volume_shape = (self.sinos_shape[0], self.sinos_shape[2], self.sinos_shape[2])
116
+ self.volume_shape = volume_shape
117
+ self.n_z, self.n_y, self.n_x = self.volume_shape
118
+ self.source_origin_dist = source_origin_dist
119
+ self.origin_detector_dist = origin_detector_dist
120
+ self.magnification = 1 + origin_detector_dist / source_origin_dist
121
+ self._set_slice_roi(slice_roi)
122
+ self.vol_geom = astra.create_vol_geom(self._vol_geom_n_y, self._vol_geom_n_x, self.n_z)
123
+ self.vol_shape = astra.geom_size(self.vol_geom)
124
+ self._cor_shift = 0.0
125
+ self.rot_center = rot_center
126
+ if rot_center is not None:
127
+ self._cor_shift = (self.sinos_shape[-1] - 1) / 2.0 - rot_center
128
+ self._set_pixel_size(pixel_size)
129
+ self._axis_corrections = self.extra_options.get("axis_correction", None)
130
+ self._create_astra_proj_geometry(relative_z_position)
131
+
132
+ def _create_astra_proj_geometry(self, relative_z_position):
133
+ # This object has to be re-created each time, because once the modifications below are done,
134
+ # it is no more a "cone" geometry but a "cone_vec" geometry, and cannot be updated subsequently
135
+ # (see astra/functions.py:271)
136
+ self.proj_geom = astra.create_proj_geom(
137
+ "cone",
138
+ self._det_spacing_x,
139
+ self._det_spacing_y,
140
+ self.n_sinos,
141
+ self.prj_width,
142
+ self.angles,
143
+ self.source_origin_dist,
144
+ self.origin_detector_dist,
145
+ )
146
+ self.relative_z_position = relative_z_position or 0.0
147
+ # This will turn the geometry of type "cone" into a geometry of type "cone_vec"
148
+ if self._orig_prog_geom is None:
149
+ self._orig_prog_geom = self.proj_geom
150
+ self.proj_geom = astra.geom_postalignment(self.proj_geom, (self._cor_shift, 0))
151
+ # (src, detector_center, u, v) = (srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ)
152
+ vecs = self.proj_geom["Vectors"]
153
+
154
+ # To adapt the center of rotation:
155
+ # dX = cor_shift * cos(theta) - origin_detector_dist * sin(theta)
156
+ # dY = origin_detector_dist * cos(theta) + cor_shift * sin(theta)
157
+ if self._axis_corrections is not None:
158
+ # should we check that dX and dY match the above formulas ?
159
+ cor_shifts = self._cor_shift + self._axis_corrections
160
+ vecs[:, 3] = cor_shifts * np.cos(self.angles) - self.origin_detector_dist * np.sin(self.angles)
161
+ vecs[:, 4] = self.origin_detector_dist * np.cos(self.angles) + cor_shifts * np.sin(self.angles)
162
+
163
+ # To adapt the z position:
164
+ # Component 2 of vecs is the z coordinate of the source, component 5 is the z component of the detector position
165
+ # We need to re-create the same inclination of the cone beam, thus we need to keep the inclination of the two z positions.
166
+ # The detector is centered on the rotation axis, thus moving it up or down, just moves it out of the reconstruction volume.
167
+ # We can bring back the detector in the correct volume position, by applying a rigid translation of both the detector and the source.
168
+ # The translation is exactly the amount that brought the detector up or down, but in the opposite direction.
169
+ vecs[:, 2] = -self.relative_z_position
170
+
171
+ def _set_output(self, volume):
172
+ if volume is not None:
173
+ expected_shape = self.vol_shape # if not (self._crop_data) else self._output_cropped_shape
174
+ self.cuda.check_array(volume, expected_shape)
175
+ self.cuda.set_array("output", volume)
176
+ if volume is None:
177
+ self.cuda.allocate_array("output", self.vol_shape)
178
+ d_volume = self.cuda.get_array("output")
179
+ z, y, x = d_volume.shape
180
+ self._vol_link = astra.data3d.GPULink(d_volume.ptr, x, y, z, d_volume.strides[-2])
181
+ self._vol_id = astra.data3d.link("-vol", self.vol_geom, self._vol_link)
182
+
183
+ def _set_input(self, sinos):
184
+ self.cuda.check_array(sinos, self.sinos_shape)
185
+ self.cuda.set_array("sinos", sinos) # self.cuda.sinos is now a GPU array
186
+ # TODO don't create new link/proj_id if ptr is the same ?
187
+ # But it seems Astra modifies the input sinogram while doing FDK, so this might be not relevant
188
+ d_sinos = self.cuda.get_array("sinos")
189
+
190
+ # self._proj_data_link = astra.data3d.GPULink(d_sinos.ptr, self.prj_width, self.n_angles, self.n_z, sinos.strides[-2])
191
+ self._proj_data_link = astra.data3d.GPULink(
192
+ d_sinos.ptr, self.prj_width, self.n_angles, self.n_sinos, d_sinos.strides[-2]
193
+ )
194
+ self._proj_id = astra.data3d.link("-sino", self.proj_geom, self._proj_data_link)
195
+
196
+ def _preprocess_data(self):
197
+ d_sinos = self.cuda.sinos
198
+ for i in range(d_sinos.shape[0]):
199
+ self.sino_filter.filter_sino(d_sinos[i], output=d_sinos[i])
200
+
201
+ def _update_reconstruction(self):
202
+ cfg = astra.astra_dict("BP3D_CUDA")
203
+ cfg["ReconstructionDataId"] = self._vol_id
204
+ cfg["ProjectionDataId"] = self._proj_id
205
+ if self._alg_id is not None:
206
+ astra.algorithm.delete(self._alg_id)
207
+ self._alg_id = astra.algorithm.create(cfg)
208
+
209
+ def reconstruct(self, sinos, output=None, relative_z_position=None):
210
+ """
211
+ sinos: numpy.ndarray or pycuda.gpuarray
212
+ Sinograms, with shape (n_sinograms, n_angles, width)
213
+ output: pycuda.gpuarray, optional
214
+ Output array. If not provided, a new numpy array is returned
215
+ relative_z_position: int, optional
216
+ Position of the central slice of the slab, with respect to the full stack of slices.
217
+ By default it is set to zero, meaning that the current slab is assumed in the middle of the stack
218
+ """
219
+ self._create_astra_proj_geometry(relative_z_position)
220
+ self._set_input(sinos)
221
+ self._set_output(output)
222
+ self._preprocess_data()
223
+ self._update_reconstruction()
224
+ astra.algorithm.run(self._alg_id)
225
+ #
226
+ # NB: Could also be done with
227
+ # from astra.experimental import direct_BP3D
228
+ # projector_id = astra.create_projector("cuda3d", self.proj_geom, self.vol_geom, options=None)
229
+ # direct_BP3D(projector_id, self._vol_link, self._proj_data_link)
230
+ #
231
+ result = self.cuda.get_array("output")
232
+ if output is None:
233
+ result = result.get()
234
+ if self.extra_options.get("scale_factor", None) is not None:
235
+ result *= np.float32(self.extra_options["scale_factor"]) # in-place for pycuda
236
+ self.cuda.recover_arrays_references(["sinos", "output"])
237
+ return result
238
+
239
+ def __del__(self):
240
+ if getattr(self, "_alg_id", None) is not None:
241
+ astra.algorithm.delete(self._alg_id)
242
+ if getattr(self, "_vol_id", None) is not None:
243
+ astra.data3d.delete(self._vol_id)
244
+ if getattr(self, "_proj_id", None) is not None:
245
+ astra.data3d.delete(self._proj_id)