nabu 2025.1.0.dev5__py3-none-any.whl → 2025.1.0.dev12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/double_flatfield.py +18 -5
- nabu/app/reconstruct_helical.py +4 -4
- nabu/app/stitching.py +7 -2
- nabu/cuda/src/backproj.cu +10 -10
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/utils.py +1 -1
- nabu/estimation/cor.py +3 -3
- nabu/io/cast_volume.py +13 -0
- nabu/io/reader.py +3 -2
- nabu/opencl/src/backproj.cl +10 -10
- nabu/pipeline/estimators.py +6 -6
- nabu/pipeline/fullfield/chunked.py +13 -13
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +16 -4
- nabu/pipeline/fullfield/processconfig.py +22 -2
- nabu/pipeline/fullfield/reconstruction.py +9 -4
- nabu/pipeline/helical/gridded_accumulator.py +1 -1
- nabu/pipeline/helical/helical_reconstruction.py +2 -2
- nabu/pipeline/helical/nabu_config.py +1 -1
- nabu/pipeline/helical/weight_balancer.py +1 -1
- nabu/pipeline/params.py +8 -3
- nabu/preproc/shift.py +1 -1
- nabu/processing/fft_base.py +6 -2
- nabu/processing/fft_cuda.py +23 -4
- nabu/processing/fft_opencl.py +19 -2
- nabu/processing/padding_cuda.py +0 -1
- nabu/processing/processing_base.py +11 -5
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +34 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +8 -0
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +21 -20
- nabu/reconstruction/filtering_opencl.py +8 -14
- nabu/reconstruction/hbp.py +10 -10
- nabu/reconstruction/rings_cuda.py +41 -13
- nabu/reconstruction/tests/test_cone.py +35 -0
- nabu/reconstruction/tests/test_fbp.py +32 -11
- nabu/reconstruction/tests/test_filtering.py +14 -5
- nabu/resources/dataset_analyzer.py +34 -2
- nabu/resources/tests/test_extract.py +4 -2
- nabu/stitching/config.py +6 -1
- nabu/stitching/stitcher/dumper/__init__.py +1 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +105 -1
- nabu/stitching/stitcher/post_processing.py +14 -4
- nabu/stitching/stitcher/pre_processing.py +1 -1
- nabu/stitching/stitcher/single_axis.py +8 -7
- nabu/stitching/stitcher/z_stitcher.py +8 -4
- nabu/stitching/utils/utils.py +2 -2
- nabu/testutils.py +2 -2
- nabu/utils.py +9 -2
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/METADATA +9 -28
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/RECORD +60 -57
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/WHEEL +1 -1
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/entry_points.txt +0 -0
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info/licenses}/LICENSE +0 -0
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev12.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import posixpath
|
3
3
|
import numpy as np
|
4
|
+
from .get_double_flatfield import get_double_flatfield
|
4
5
|
from silx.io import get_data
|
5
6
|
from silx.io.url import DataUrl
|
6
7
|
from ...utils import copy_dict_items, compare_dicts
|
@@ -32,6 +33,7 @@ class ProcessConfig(ProcessConfigBase):
|
|
32
33
|
|
33
34
|
(2) update_dataset_info_with_user_config
|
34
35
|
- Update flats/darks
|
36
|
+
- Double-flat-field
|
35
37
|
- CoR (value or estimation method) # no estimation yet
|
36
38
|
- rotation angles
|
37
39
|
- translations files
|
@@ -89,12 +91,29 @@ class ProcessConfig(ProcessConfigBase):
|
|
89
91
|
self.subsampling_factor = subsampling_factor or 1
|
90
92
|
self.subsampling_start = subsampling_start or 0
|
91
93
|
|
94
|
+
self._get_double_flatfield()
|
92
95
|
self._update_dataset_with_user_overwrites()
|
93
96
|
self._get_rotation_axis_position()
|
94
97
|
self._update_rotation_angles()
|
95
98
|
self._get_translation_file("reconstruction", "translation_movements_file", "translations")
|
96
99
|
self._get_user_sino_normalization()
|
97
100
|
|
101
|
+
def _get_double_flatfield(self):
|
102
|
+
self._dff_file = None
|
103
|
+
dff_mode = self.nabu_config["preproc"]["double_flatfield"]
|
104
|
+
if not (dff_mode):
|
105
|
+
return
|
106
|
+
self._dff_file = get_double_flatfield(
|
107
|
+
self.dataset_info,
|
108
|
+
dff_mode,
|
109
|
+
output_dir=self.nabu_config["output"]["location"],
|
110
|
+
darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"],
|
111
|
+
dff_options={
|
112
|
+
"dff_sigma": self.nabu_config["preproc"]["dff_sigma"],
|
113
|
+
"do_flatfield": (self.nabu_config["preproc"]["flatfield"] is not False),
|
114
|
+
},
|
115
|
+
)
|
116
|
+
|
98
117
|
def _update_dataset_with_user_overwrites(self):
|
99
118
|
user_overwrites = self.nabu_config["dataset"]["overwrite_metadata"].strip()
|
100
119
|
if user_overwrites in ("", None):
|
@@ -451,11 +470,11 @@ class ProcessConfig(ProcessConfigBase):
|
|
451
470
|
#
|
452
471
|
# Double flat field
|
453
472
|
#
|
454
|
-
if nabu_config["preproc"]["
|
473
|
+
if nabu_config["preproc"]["double_flatfield"]:
|
455
474
|
tasks.append("double_flatfield")
|
456
475
|
options["double_flatfield"] = {
|
457
476
|
"sigma": nabu_config["preproc"]["dff_sigma"],
|
458
|
-
"processes_file": nabu_config["preproc"]["processes_file"],
|
477
|
+
"processes_file": self._dff_file or nabu_config["preproc"]["processes_file"],
|
459
478
|
"log_min_clip": nabu_config["preproc"]["log_min_clip"],
|
460
479
|
"log_max_clip": nabu_config["preproc"]["log_max_clip"],
|
461
480
|
}
|
@@ -575,6 +594,7 @@ class ProcessConfig(ProcessConfigBase):
|
|
575
594
|
"sample_detector_dist",
|
576
595
|
"hbp_legs",
|
577
596
|
"hbp_reduction_steps",
|
597
|
+
"crop_filtered_data",
|
578
598
|
],
|
579
599
|
)
|
580
600
|
rec_options = options["reconstruction"]
|
@@ -120,7 +120,7 @@ class FullFieldReconstructor:
|
|
120
120
|
vm = virtual_memory()
|
121
121
|
self.resources["mem_avail_GB"] = vm.available / 1e9
|
122
122
|
# Account for other memory constraints. There might be a better way
|
123
|
-
slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0))
|
123
|
+
slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0)) # noqa: PLW1508
|
124
124
|
if slurm_mem_constraint_MB > 0:
|
125
125
|
self.resources["mem_avail_GB"] = slurm_mem_constraint_MB / 1e3
|
126
126
|
#
|
@@ -131,8 +131,12 @@ class FullFieldReconstructor:
|
|
131
131
|
self.resources["gpus"] = avail_gpus
|
132
132
|
if len(avail_gpus) == 0:
|
133
133
|
return
|
134
|
-
|
135
|
-
|
134
|
+
user_gpus = self.process_config.nabu_config.get("resources", {}).get("gpu_id", [])
|
135
|
+
if len(user_gpus) == 0:
|
136
|
+
user_gpus = [0]
|
137
|
+
# For now nabu does not support multi-GPU reconstruction. Take the first one.
|
138
|
+
user_gpu_idx = user_gpus[0]
|
139
|
+
self.resources["gpu_id"] = self._gpu_id = list(avail_gpus.keys())[user_gpu_idx]
|
136
140
|
|
137
141
|
def _get_backend(self, backend, cuda_options):
|
138
142
|
self._pipeline_cls = ChunkedPipeline
|
@@ -145,6 +149,7 @@ class FullFieldReconstructor:
|
|
145
149
|
backend = "numpy"
|
146
150
|
else:
|
147
151
|
self.gpu_mem = self.resources["gpus"][self._gpu_id]["memory_GB"] * self.gpu_mem_fraction
|
152
|
+
self.cuda_options = {"device_id": self._gpu_id}
|
148
153
|
if backend == "cuda":
|
149
154
|
if not (__has_pycuda__):
|
150
155
|
raise RuntimeError("pycuda not avilable")
|
@@ -307,7 +312,7 @@ class FullFieldReconstructor:
|
|
307
312
|
sigma = opts["unsharp_sigma"]
|
308
313
|
# nabu uses cutoff = 4
|
309
314
|
cutoff = 4
|
310
|
-
gaussian_kernel_size =
|
315
|
+
gaussian_kernel_size = ceil(2 * cutoff * sigma + 1)
|
311
316
|
self.logger.debug("Unsharp mask margin: %d pixels" % gaussian_kernel_size)
|
312
317
|
return (gaussian_kernel_size, gaussian_kernel_size)
|
313
318
|
|
@@ -532,7 +532,7 @@ def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_ma
|
|
532
532
|
# regridded dataset, estimating a meaningul angular step representative
|
533
533
|
# of the raw data
|
534
534
|
my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean())
|
535
|
-
n_gridded_angles =
|
535
|
+
n_gridded_angles = round(360.0 / my_angle_step)
|
536
536
|
|
537
537
|
radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix
|
538
538
|
|
@@ -168,8 +168,8 @@ class HelicalReconstructorRegridded:
|
|
168
168
|
|
169
169
|
# the meaming of z_min and z_max is: position in slices units from the
|
170
170
|
# first available slice and in the direction of the scan
|
171
|
-
self.z_min =
|
172
|
-
self.z_max =
|
171
|
+
self.z_min = round(z_start * (0 - z_fract_min) + z_max * z_fract_min)
|
172
|
+
self.z_max = round(z_start * (0 - z_fract_max) + z_max * z_fract_max) + 1
|
173
173
|
|
174
174
|
def _compute_translations_margin(self):
|
175
175
|
return 0, 0
|
@@ -43,7 +43,7 @@ nabu_config["preproc"]["processes_file"] = {
|
|
43
43
|
"validator": optional_file_location_validator,
|
44
44
|
"type": "required",
|
45
45
|
}
|
46
|
-
nabu_config["preproc"]["
|
46
|
+
nabu_config["preproc"]["double_flatfield"]["default"] = 1
|
47
47
|
|
48
48
|
|
49
49
|
nabu_config["reconstruction"].update(
|
nabu/pipeline/params.py
CHANGED
@@ -25,12 +25,17 @@ unsharp_methods = {
|
|
25
25
|
"": None,
|
26
26
|
}
|
27
27
|
|
28
|
+
# see PaddingBase.supported_modes
|
28
29
|
padding_modes = {
|
29
|
-
"edges": "edge",
|
30
|
-
"edge": "edge",
|
31
|
-
"mirror": "mirror",
|
32
30
|
"zeros": "zeros",
|
33
31
|
"zero": "zeros",
|
32
|
+
"constant": "zeros",
|
33
|
+
"edges": "edge",
|
34
|
+
"edge": "edge",
|
35
|
+
"mirror": "reflect",
|
36
|
+
"reflect": "reflect",
|
37
|
+
"symmetric": "symmetric",
|
38
|
+
"wrap": "wrap",
|
34
39
|
}
|
35
40
|
|
36
41
|
reconstruction_methods = {
|
nabu/preproc/shift.py
CHANGED
nabu/processing/fft_base.py
CHANGED
@@ -93,6 +93,10 @@ class _BaseFFT:
|
|
93
93
|
pass
|
94
94
|
|
95
95
|
|
96
|
+
def raise_base_class_error(slf, *args, **kwargs):
|
97
|
+
raise ValueError
|
98
|
+
|
99
|
+
|
96
100
|
class _BaseVKFFT(_BaseFFT):
|
97
101
|
"""
|
98
102
|
FFT using VKFFT backend
|
@@ -101,7 +105,7 @@ class _BaseVKFFT(_BaseFFT):
|
|
101
105
|
implem = "vkfft"
|
102
106
|
backend = "none"
|
103
107
|
ProcessingCls = BaseClassError
|
104
|
-
|
108
|
+
get_fft_obj = raise_base_class_error
|
105
109
|
|
106
110
|
def _configure_batched_transform(self):
|
107
111
|
if self.axes is not None and len(self.shape) == len(self.axes):
|
@@ -128,7 +132,7 @@ class _BaseVKFFT(_BaseFFT):
|
|
128
132
|
self._vkfft_ndim = None
|
129
133
|
|
130
134
|
def _compute_fft_plans(self):
|
131
|
-
self._vkfft_plan = self.
|
135
|
+
self._vkfft_plan = self.get_fft_obj(
|
132
136
|
self.shape,
|
133
137
|
self.dtype,
|
134
138
|
ndim=self._vkfft_ndim,
|
nabu/processing/fft_cuda.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
+
from functools import lru_cache
|
3
4
|
from multiprocessing import get_context
|
4
5
|
from multiprocessing.pool import Pool
|
5
6
|
import numpy as np
|
6
|
-
from ..utils import check_supported
|
7
|
+
from ..utils import BaseClassError, check_supported, no_decorator
|
7
8
|
from .fft_base import _BaseFFT, _BaseVKFFT
|
8
9
|
|
9
10
|
try:
|
10
|
-
from pyvkfft.cuda import VkFFTApp as
|
11
|
+
from pyvkfft.cuda import VkFFTApp as CudaVkFFTApp
|
11
12
|
|
12
13
|
__has_vkfft__ = True
|
13
14
|
except (ImportError, OSError):
|
14
15
|
__has_vkfft__ = False
|
15
|
-
|
16
|
+
CudaVkFFTApp = BaseClassError
|
16
17
|
from ..cuda.processing import CudaProcessing
|
17
18
|
|
18
19
|
Plan = None
|
@@ -20,6 +21,8 @@ cu_fft = None
|
|
20
21
|
cu_ifft = None
|
21
22
|
__has_skcuda__ = None
|
22
23
|
|
24
|
+
n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
|
25
|
+
|
23
26
|
|
24
27
|
def init_skcuda():
|
25
28
|
# This needs to be done here, because scikit-cuda creates a Cuda context at import,
|
@@ -146,6 +149,18 @@ class SKCUFFT(_BaseFFT):
|
|
146
149
|
return output
|
147
150
|
|
148
151
|
|
152
|
+
maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
|
153
|
+
|
154
|
+
|
155
|
+
@maybe_cached
|
156
|
+
def _get_vkfft_cuda(*args, **kwargs):
|
157
|
+
return CudaVkFFTApp(*args, **kwargs)
|
158
|
+
|
159
|
+
|
160
|
+
def get_vkfft_cuda(slf, *args, **kwargs):
|
161
|
+
return _get_vkfft_cuda(*args, **kwargs)
|
162
|
+
|
163
|
+
|
149
164
|
class VKCUFFT(_BaseVKFFT):
|
150
165
|
"""
|
151
166
|
Cuda FFT, using VKFFT backend
|
@@ -154,7 +169,7 @@ class VKCUFFT(_BaseVKFFT):
|
|
154
169
|
implem = "vkfft"
|
155
170
|
backend = "cuda"
|
156
171
|
ProcessingCls = CudaProcessing
|
157
|
-
|
172
|
+
get_fft_obj = get_vkfft_cuda
|
158
173
|
|
159
174
|
def _init_backend(self, backend_options):
|
160
175
|
super()._init_backend(backend_options)
|
@@ -175,6 +190,7 @@ def _has_vkfft(x):
|
|
175
190
|
return avail
|
176
191
|
|
177
192
|
|
193
|
+
@lru_cache(maxsize=2)
|
178
194
|
def has_vkfft(safe=True):
|
179
195
|
"""
|
180
196
|
Determine whether pyvkfft is available.
|
@@ -208,6 +224,7 @@ def _has_skfft(x):
|
|
208
224
|
return avail
|
209
225
|
|
210
226
|
|
227
|
+
@lru_cache(maxsize=2)
|
211
228
|
def has_skcuda(safe=True):
|
212
229
|
"""
|
213
230
|
Determine whether scikit-cuda/CUFFT is available.
|
@@ -227,6 +244,7 @@ def has_skcuda(safe=True):
|
|
227
244
|
return v
|
228
245
|
|
229
246
|
|
247
|
+
@lru_cache(maxsize=2)
|
230
248
|
def get_fft_class(backend="vkfft"):
|
231
249
|
backends = {
|
232
250
|
"scikit-cuda": SKCUFFT,
|
@@ -256,6 +274,7 @@ def get_fft_class(backend="vkfft"):
|
|
256
274
|
return get_fft_cls(backend)
|
257
275
|
|
258
276
|
|
277
|
+
@lru_cache(maxsize=1)
|
259
278
|
def get_available_fft_implems():
|
260
279
|
avail_implems = []
|
261
280
|
if has_vkfft(safe=True):
|
nabu/processing/fft_opencl.py
CHANGED
@@ -1,15 +1,32 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
import os
|
1
3
|
from multiprocessing import get_context
|
2
4
|
from multiprocessing.pool import Pool
|
5
|
+
|
6
|
+
from ..utils import BaseClassError, no_decorator
|
3
7
|
from .fft_base import _BaseVKFFT
|
4
8
|
from ..opencl.processing import OpenCLProcessing
|
5
9
|
|
6
10
|
try:
|
7
|
-
from pyvkfft.opencl import VkFFTApp as
|
11
|
+
from pyvkfft.opencl import VkFFTApp as OpenCLVkFFTApp
|
8
12
|
|
9
13
|
__has_vkfft__ = True
|
10
14
|
except (ImportError, OSError):
|
11
15
|
__has_vkfft__ = False
|
12
16
|
vk_clfft = None
|
17
|
+
OpenCLVkFFTApp = BaseClassError
|
18
|
+
|
19
|
+
n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
|
20
|
+
maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
|
21
|
+
|
22
|
+
|
23
|
+
@maybe_cached
|
24
|
+
def _get_vkfft_opencl(*args, **kwargs):
|
25
|
+
return OpenCLVkFFTApp(*args, **kwargs)
|
26
|
+
|
27
|
+
|
28
|
+
def get_vkfft_opencl(slf, *args, **kwargs):
|
29
|
+
return _get_vkfft_opencl(*args, **kwargs)
|
13
30
|
|
14
31
|
|
15
32
|
class VKCLFFT(_BaseVKFFT):
|
@@ -20,7 +37,7 @@ class VKCLFFT(_BaseVKFFT):
|
|
20
37
|
implem = "vkfft"
|
21
38
|
backend = "opencl"
|
22
39
|
ProcessingCls = OpenCLProcessing
|
23
|
-
|
40
|
+
get_fft_obj = get_vkfft_opencl
|
24
41
|
|
25
42
|
def _init_backend(self, backend_options):
|
26
43
|
super()._init_backend(backend_options)
|
nabu/processing/padding_cuda.py
CHANGED
@@ -11,7 +11,6 @@ class CudaPadding(PaddingBase):
|
|
11
11
|
|
12
12
|
backend = "cuda"
|
13
13
|
|
14
|
-
# TODO docstring from base class
|
15
14
|
def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs):
|
16
15
|
super().__init__(shape, pad_width, mode=mode, **kwargs)
|
17
16
|
self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {}))
|
@@ -99,6 +99,15 @@ class ProcessingBase:
|
|
99
99
|
_recover_arrays_references = recover_arrays_references
|
100
100
|
_allocate_array = allocate_array
|
101
101
|
_set_array = set_array
|
102
|
+
# --
|
103
|
+
|
104
|
+
def is_contiguous(self, arr):
|
105
|
+
if isinstance(arr, self.array_class):
|
106
|
+
return arr.flags.c_contiguous
|
107
|
+
elif isinstance(arr, np.ndarray):
|
108
|
+
return arr.flags["C_CONTIGUOUS"]
|
109
|
+
else:
|
110
|
+
raise TypeError
|
102
111
|
|
103
112
|
def check_array(self, arr, expected_shape, expected_dtype="f", check_contiguous=True):
|
104
113
|
"""
|
@@ -108,11 +117,8 @@ class ProcessingBase:
|
|
108
117
|
raise ValueError("Expected shape %s but got %s" % (str(expected_shape), str(arr.shape)))
|
109
118
|
if arr.dtype != np.dtype(expected_dtype):
|
110
119
|
raise ValueError("Expected data type %s but got %s" % (str(expected_dtype), str(arr.dtype)))
|
111
|
-
if check_contiguous:
|
112
|
-
|
113
|
-
raise ValueError("Expected C-contiguous array")
|
114
|
-
if isinstance(arr, self.array_class) and not arr.flags.c_contiguous:
|
115
|
-
raise ValueError("Expected C-contiguous array")
|
120
|
+
if check_contiguous and not (self.is_contiguous(arr)):
|
121
|
+
raise ValueError("Expected C-contiguous array")
|
116
122
|
|
117
123
|
def kernel(self, *args, **kwargs):
|
118
124
|
raise ValueError("Base class")
|
@@ -0,0 +1,245 @@
|
|
1
|
+
# ruff: noqa
|
2
|
+
try:
|
3
|
+
import astra
|
4
|
+
|
5
|
+
__have_astra__ = True
|
6
|
+
except ImportError:
|
7
|
+
__have_astra__ = False
|
8
|
+
astra = None
|
9
|
+
|
10
|
+
|
11
|
+
class AstraReconstructor:
|
12
|
+
"""
|
13
|
+
Base class for reconstructors based on the Astra toolbox
|
14
|
+
"""
|
15
|
+
|
16
|
+
default_extra_options = {
|
17
|
+
"axis_correction": None,
|
18
|
+
"clip_outer_circle": False,
|
19
|
+
"scale_factor": None,
|
20
|
+
"filter_cutoff": 1.0,
|
21
|
+
"outer_circle_value": 0.0,
|
22
|
+
}
|
23
|
+
|
24
|
+
def __init__(
|
25
|
+
self,
|
26
|
+
sinos_shape,
|
27
|
+
angles=None,
|
28
|
+
volume_shape=None,
|
29
|
+
rot_center=None,
|
30
|
+
pixel_size=None,
|
31
|
+
padding_mode="zeros",
|
32
|
+
filter_name=None,
|
33
|
+
slice_roi=None,
|
34
|
+
cuda_options=None,
|
35
|
+
extra_options=None,
|
36
|
+
):
|
37
|
+
self._configure_extra_options(extra_options)
|
38
|
+
self._init_cuda(cuda_options)
|
39
|
+
self._set_sino_shape(sinos_shape)
|
40
|
+
self._orig_prog_geom = None
|
41
|
+
self._init_geometry(
|
42
|
+
source_origin_dist,
|
43
|
+
origin_detector_dist,
|
44
|
+
pixel_size,
|
45
|
+
angles,
|
46
|
+
volume_shape,
|
47
|
+
rot_center,
|
48
|
+
relative_z_position,
|
49
|
+
slice_roi,
|
50
|
+
)
|
51
|
+
self._init_fdk(padding_mode, filter_name)
|
52
|
+
self._alg_id = None
|
53
|
+
self._vol_id = None
|
54
|
+
self._proj_id = None
|
55
|
+
|
56
|
+
def _configure_extra_options(self, extra_options):
|
57
|
+
self.extra_options = self.default_extra_options.copy()
|
58
|
+
self.extra_options.update(extra_options or {})
|
59
|
+
|
60
|
+
def _init_cuda(self, cuda_options):
|
61
|
+
cuda_options = cuda_options or {}
|
62
|
+
self.cuda = CudaProcessing(**cuda_options)
|
63
|
+
|
64
|
+
def _set_sino_shape(self, sinos_shape):
|
65
|
+
if len(sinos_shape) != 3:
|
66
|
+
raise ValueError("Expected a 3D shape")
|
67
|
+
self.sinos_shape = sinos_shape
|
68
|
+
self.n_sinos, self.n_angles, self.prj_width = sinos_shape
|
69
|
+
|
70
|
+
def _set_pixel_size(self, pixel_size):
|
71
|
+
if pixel_size is None:
|
72
|
+
det_spacing_y = det_spacing_x = 1
|
73
|
+
elif np.iterable(pixel_size):
|
74
|
+
det_spacing_y, det_spacing_x = pixel_size
|
75
|
+
else:
|
76
|
+
# assuming scalar
|
77
|
+
det_spacing_y = det_spacing_x = pixel_size
|
78
|
+
self._det_spacing_y = det_spacing_y
|
79
|
+
self._det_spacing_x = det_spacing_x
|
80
|
+
|
81
|
+
def _set_slice_roi(self, slice_roi):
|
82
|
+
self.slice_roi = slice_roi
|
83
|
+
self._vol_geom_n_x = self.n_x
|
84
|
+
self._vol_geom_n_y = self.n_y
|
85
|
+
self._crop_data = True
|
86
|
+
if slice_roi is None:
|
87
|
+
return
|
88
|
+
start_x, end_x, start_y, end_y = slice_roi
|
89
|
+
if roi_is_centered(self.volume_shape[1:], (slice(start_y, end_y), slice(start_x, end_x))):
|
90
|
+
# Astra can only reconstruct subregion centered around the origin
|
91
|
+
self._vol_geom_n_x = self.n_x - start_x * 2
|
92
|
+
self._vol_geom_n_y = self.n_y - start_y * 2
|
93
|
+
else:
|
94
|
+
raise NotImplementedError(
|
95
|
+
"Astra supports only slice_roi centered around origin (got slice_roi=%s with n_x=%d, n_y=%d)"
|
96
|
+
% (str(slice_roi), self.n_x, self.n_y)
|
97
|
+
)
|
98
|
+
|
99
|
+
def _init_geometry(
|
100
|
+
self,
|
101
|
+
source_origin_dist,
|
102
|
+
origin_detector_dist,
|
103
|
+
pixel_size,
|
104
|
+
angles,
|
105
|
+
volume_shape,
|
106
|
+
rot_center,
|
107
|
+
relative_z_position,
|
108
|
+
slice_roi,
|
109
|
+
):
|
110
|
+
if angles is None:
|
111
|
+
self.angles = np.linspace(0, 2 * np.pi, self.n_angles, endpoint=True)
|
112
|
+
else:
|
113
|
+
self.angles = angles
|
114
|
+
if volume_shape is None:
|
115
|
+
volume_shape = (self.sinos_shape[0], self.sinos_shape[2], self.sinos_shape[2])
|
116
|
+
self.volume_shape = volume_shape
|
117
|
+
self.n_z, self.n_y, self.n_x = self.volume_shape
|
118
|
+
self.source_origin_dist = source_origin_dist
|
119
|
+
self.origin_detector_dist = origin_detector_dist
|
120
|
+
self.magnification = 1 + origin_detector_dist / source_origin_dist
|
121
|
+
self._set_slice_roi(slice_roi)
|
122
|
+
self.vol_geom = astra.create_vol_geom(self._vol_geom_n_y, self._vol_geom_n_x, self.n_z)
|
123
|
+
self.vol_shape = astra.geom_size(self.vol_geom)
|
124
|
+
self._cor_shift = 0.0
|
125
|
+
self.rot_center = rot_center
|
126
|
+
if rot_center is not None:
|
127
|
+
self._cor_shift = (self.sinos_shape[-1] - 1) / 2.0 - rot_center
|
128
|
+
self._set_pixel_size(pixel_size)
|
129
|
+
self._axis_corrections = self.extra_options.get("axis_correction", None)
|
130
|
+
self._create_astra_proj_geometry(relative_z_position)
|
131
|
+
|
132
|
+
def _create_astra_proj_geometry(self, relative_z_position):
|
133
|
+
# This object has to be re-created each time, because once the modifications below are done,
|
134
|
+
# it is no more a "cone" geometry but a "cone_vec" geometry, and cannot be updated subsequently
|
135
|
+
# (see astra/functions.py:271)
|
136
|
+
self.proj_geom = astra.create_proj_geom(
|
137
|
+
"cone",
|
138
|
+
self._det_spacing_x,
|
139
|
+
self._det_spacing_y,
|
140
|
+
self.n_sinos,
|
141
|
+
self.prj_width,
|
142
|
+
self.angles,
|
143
|
+
self.source_origin_dist,
|
144
|
+
self.origin_detector_dist,
|
145
|
+
)
|
146
|
+
self.relative_z_position = relative_z_position or 0.0
|
147
|
+
# This will turn the geometry of type "cone" into a geometry of type "cone_vec"
|
148
|
+
if self._orig_prog_geom is None:
|
149
|
+
self._orig_prog_geom = self.proj_geom
|
150
|
+
self.proj_geom = astra.geom_postalignment(self.proj_geom, (self._cor_shift, 0))
|
151
|
+
# (src, detector_center, u, v) = (srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ)
|
152
|
+
vecs = self.proj_geom["Vectors"]
|
153
|
+
|
154
|
+
# To adapt the center of rotation:
|
155
|
+
# dX = cor_shift * cos(theta) - origin_detector_dist * sin(theta)
|
156
|
+
# dY = origin_detector_dist * cos(theta) + cor_shift * sin(theta)
|
157
|
+
if self._axis_corrections is not None:
|
158
|
+
# should we check that dX and dY match the above formulas ?
|
159
|
+
cor_shifts = self._cor_shift + self._axis_corrections
|
160
|
+
vecs[:, 3] = cor_shifts * np.cos(self.angles) - self.origin_detector_dist * np.sin(self.angles)
|
161
|
+
vecs[:, 4] = self.origin_detector_dist * np.cos(self.angles) + cor_shifts * np.sin(self.angles)
|
162
|
+
|
163
|
+
# To adapt the z position:
|
164
|
+
# Component 2 of vecs is the z coordinate of the source, component 5 is the z component of the detector position
|
165
|
+
# We need to re-create the same inclination of the cone beam, thus we need to keep the inclination of the two z positions.
|
166
|
+
# The detector is centered on the rotation axis, thus moving it up or down, just moves it out of the reconstruction volume.
|
167
|
+
# We can bring back the detector in the correct volume position, by applying a rigid translation of both the detector and the source.
|
168
|
+
# The translation is exactly the amount that brought the detector up or down, but in the opposite direction.
|
169
|
+
vecs[:, 2] = -self.relative_z_position
|
170
|
+
|
171
|
+
def _set_output(self, volume):
|
172
|
+
if volume is not None:
|
173
|
+
expected_shape = self.vol_shape # if not (self._crop_data) else self._output_cropped_shape
|
174
|
+
self.cuda.check_array(volume, expected_shape)
|
175
|
+
self.cuda.set_array("output", volume)
|
176
|
+
if volume is None:
|
177
|
+
self.cuda.allocate_array("output", self.vol_shape)
|
178
|
+
d_volume = self.cuda.get_array("output")
|
179
|
+
z, y, x = d_volume.shape
|
180
|
+
self._vol_link = astra.data3d.GPULink(d_volume.ptr, x, y, z, d_volume.strides[-2])
|
181
|
+
self._vol_id = astra.data3d.link("-vol", self.vol_geom, self._vol_link)
|
182
|
+
|
183
|
+
def _set_input(self, sinos):
|
184
|
+
self.cuda.check_array(sinos, self.sinos_shape)
|
185
|
+
self.cuda.set_array("sinos", sinos) # self.cuda.sinos is now a GPU array
|
186
|
+
# TODO don't create new link/proj_id if ptr is the same ?
|
187
|
+
# But it seems Astra modifies the input sinogram while doing FDK, so this might be not relevant
|
188
|
+
d_sinos = self.cuda.get_array("sinos")
|
189
|
+
|
190
|
+
# self._proj_data_link = astra.data3d.GPULink(d_sinos.ptr, self.prj_width, self.n_angles, self.n_z, sinos.strides[-2])
|
191
|
+
self._proj_data_link = astra.data3d.GPULink(
|
192
|
+
d_sinos.ptr, self.prj_width, self.n_angles, self.n_sinos, d_sinos.strides[-2]
|
193
|
+
)
|
194
|
+
self._proj_id = astra.data3d.link("-sino", self.proj_geom, self._proj_data_link)
|
195
|
+
|
196
|
+
def _preprocess_data(self):
|
197
|
+
d_sinos = self.cuda.sinos
|
198
|
+
for i in range(d_sinos.shape[0]):
|
199
|
+
self.sino_filter.filter_sino(d_sinos[i], output=d_sinos[i])
|
200
|
+
|
201
|
+
def _update_reconstruction(self):
|
202
|
+
cfg = astra.astra_dict("BP3D_CUDA")
|
203
|
+
cfg["ReconstructionDataId"] = self._vol_id
|
204
|
+
cfg["ProjectionDataId"] = self._proj_id
|
205
|
+
if self._alg_id is not None:
|
206
|
+
astra.algorithm.delete(self._alg_id)
|
207
|
+
self._alg_id = astra.algorithm.create(cfg)
|
208
|
+
|
209
|
+
def reconstruct(self, sinos, output=None, relative_z_position=None):
|
210
|
+
"""
|
211
|
+
sinos: numpy.ndarray or pycuda.gpuarray
|
212
|
+
Sinograms, with shape (n_sinograms, n_angles, width)
|
213
|
+
output: pycuda.gpuarray, optional
|
214
|
+
Output array. If not provided, a new numpy array is returned
|
215
|
+
relative_z_position: int, optional
|
216
|
+
Position of the central slice of the slab, with respect to the full stack of slices.
|
217
|
+
By default it is set to zero, meaning that the current slab is assumed in the middle of the stack
|
218
|
+
"""
|
219
|
+
self._create_astra_proj_geometry(relative_z_position)
|
220
|
+
self._set_input(sinos)
|
221
|
+
self._set_output(output)
|
222
|
+
self._preprocess_data()
|
223
|
+
self._update_reconstruction()
|
224
|
+
astra.algorithm.run(self._alg_id)
|
225
|
+
#
|
226
|
+
# NB: Could also be done with
|
227
|
+
# from astra.experimental import direct_BP3D
|
228
|
+
# projector_id = astra.create_projector("cuda3d", self.proj_geom, self.vol_geom, options=None)
|
229
|
+
# direct_BP3D(projector_id, self._vol_link, self._proj_data_link)
|
230
|
+
#
|
231
|
+
result = self.cuda.get_array("output")
|
232
|
+
if output is None:
|
233
|
+
result = result.get()
|
234
|
+
if self.extra_options.get("scale_factor", None) is not None:
|
235
|
+
result *= np.float32(self.extra_options["scale_factor"]) # in-place for pycuda
|
236
|
+
self.cuda.recover_arrays_references(["sinos", "output"])
|
237
|
+
return result
|
238
|
+
|
239
|
+
def __del__(self):
|
240
|
+
if getattr(self, "_alg_id", None) is not None:
|
241
|
+
astra.algorithm.delete(self._alg_id)
|
242
|
+
if getattr(self, "_vol_id", None) is not None:
|
243
|
+
astra.data3d.delete(self._vol_id)
|
244
|
+
if getattr(self, "_proj_id", None) is not None:
|
245
|
+
astra.data3d.delete(self._proj_id)
|