nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/__init__.py +0 -0
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2281
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
nabu/pipeline/writer.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from os import path
|
2
2
|
from tomoscan.esrf import TIFFVolume, MultiTIFFVolume, EDFVolume, JP2KVolume
|
3
|
+
from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase
|
3
4
|
from ..utils import check_supported, get_num_threads
|
4
5
|
from ..resources.logger import LoggerOrPrint
|
5
6
|
from ..io.writer import NXProcessWriter, HSTVolVolume, NXVolVolume
|
@@ -8,7 +9,6 @@ from .params import files_formats
|
|
8
9
|
|
9
10
|
|
10
11
|
class WriterManager:
|
11
|
-
|
12
12
|
"""
|
13
13
|
This class is a wrapper on top of all "writers".
|
14
14
|
It will create the right "writer" with all the necessary options, and the histogram writer.
|
@@ -114,6 +114,7 @@ class WriterManager:
|
|
114
114
|
return vol_writer.data_url.file_path()
|
115
115
|
|
116
116
|
def _init_writer(self):
|
117
|
+
self._writer_was_already_initialized = self.extra_options.get("writer_initialized", False)
|
117
118
|
if self.file_format in ["tiff", "edf", "jp2", "hdf5"]:
|
118
119
|
writer_kwargs = {
|
119
120
|
"folder": self.output_dir,
|
@@ -144,6 +145,11 @@ class WriterManager:
|
|
144
145
|
self._h5_entry = self.metadata.get("entry", "entry")
|
145
146
|
self.writer = self._writer_classes[self.file_format](**writer_kwargs)
|
146
147
|
self.fname = self.get_fname(self.writer)
|
148
|
+
# In certain cases, tomoscan needs to remove any previous existing volume filess
|
149
|
+
# and avoid calling 'clean_output_data' when writing downstream (for chunk processing)
|
150
|
+
if isinstance(self.writer, VolumeSingleFrameBase):
|
151
|
+
self.writer.skip_existing_data_files_removal = self._writer_was_already_initialized
|
152
|
+
# ---
|
147
153
|
if path.exists(self.fname):
|
148
154
|
err = "File already exists: %s" % self.fname
|
149
155
|
if self.overwrite:
|
@@ -188,7 +194,9 @@ class WriterManager:
|
|
188
194
|
self.writer.metadata = self.metadata
|
189
195
|
self.writer.save_metadata()
|
190
196
|
|
191
|
-
def write_data(self, data):
|
197
|
+
def write_data(self, data, metadata=None):
|
192
198
|
self.writer.data = data
|
199
|
+
if metadata is not None:
|
200
|
+
self.writer.metadata = metadata
|
193
201
|
self.writer.save()
|
194
202
|
# self._write_metadata()
|
nabu/preproc/ccd_cuda.py
CHANGED
@@ -118,7 +118,7 @@ class CudaLog(Log):
|
|
118
118
|
self._nthreadsperblock = (16, 16, 4) # TODO tune ?
|
119
119
|
self._nblocks = tuple([updiv(n, p) for n, p in zip([nx, ny, nz], self._nthreadsperblock)])
|
120
120
|
|
121
|
-
self.nlog_kernel = CudaKernel(
|
121
|
+
self.nlog_kernel = CudaKernel( # pylint: disable=E0606
|
122
122
|
"nlog",
|
123
123
|
filename=self._nlog_srcfile,
|
124
124
|
signature="Piiiff",
|
nabu/preproc/ctf.py
CHANGED
@@ -18,7 +18,7 @@ class GeoPars:
|
|
18
18
|
self,
|
19
19
|
z1_vh=None,
|
20
20
|
z2=None,
|
21
|
-
pix_size_det=
|
21
|
+
pix_size_det=1e-6,
|
22
22
|
wavelength=None,
|
23
23
|
magnification=True,
|
24
24
|
length_scale=10.0e-6,
|
@@ -33,8 +33,9 @@ class GeoPars:
|
|
33
33
|
and the horizontaly focused source (vertical line) for KB mirrors.
|
34
34
|
z2 : float
|
35
35
|
the sample detector distance (meters).
|
36
|
-
pix_size_det: float
|
37
|
-
pixel size
|
36
|
+
pix_size_det: float or tuple
|
37
|
+
pixel size in meters.
|
38
|
+
If a tuple is passed, it is interpreted as (horizontal_size, vertical_size)
|
38
39
|
wavelength: float
|
39
40
|
beam wave length (meters).
|
40
41
|
magnification: boolean defaults to True
|
@@ -55,7 +56,11 @@ class GeoPars:
|
|
55
56
|
self.z1_vh = np.array([z1_vh, z1_vh])
|
56
57
|
self.z2 = z2
|
57
58
|
self.magnification = magnification
|
58
|
-
|
59
|
+
if np.isscalar(pix_size_det):
|
60
|
+
self.pix_size_det_xy = (pix_size_det, pix_size_det)
|
61
|
+
else:
|
62
|
+
self.pix_size_det_xy = pix_size_det
|
63
|
+
self.pix_size_det = self.pix_size_det_xy[0] # COMPAT
|
59
64
|
|
60
65
|
if self.magnification and self.z1_vh is not None:
|
61
66
|
self.M_vh = (self.z1_vh + self.z2) / self.z1_vh
|
@@ -69,7 +74,9 @@ class GeoPars:
|
|
69
74
|
|
70
75
|
self.maxM = self.M_vh.max()
|
71
76
|
|
72
|
-
|
77
|
+
# we bring everything to highest magnification
|
78
|
+
self.pix_size_rec_xy = [p / self.maxM for p in self.pix_size_det_xy]
|
79
|
+
self.pix_size_rec = self.pix_size_rec_xy[0] # COMPAT
|
73
80
|
|
74
81
|
which_unit = int(np.sum(np.array([self.pix_size_rec > small for small in [1.0e-6, 1.0e-7]]).astype(np.int32)))
|
75
82
|
self.pixelsize_string = [
|
@@ -208,8 +215,8 @@ class CTFPhaseRetrieval:
|
|
208
215
|
padded_img_shape = self.shape_padded
|
209
216
|
fsample_vh = np.array(
|
210
217
|
[
|
211
|
-
self.geo_pars.length_scale / self.geo_pars.
|
212
|
-
self.geo_pars.length_scale / self.geo_pars.
|
218
|
+
self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[1],
|
219
|
+
self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[0],
|
213
220
|
]
|
214
221
|
)
|
215
222
|
|
nabu/preproc/ctf_cuda.py
CHANGED
@@ -15,7 +15,6 @@ if __has_pycuda__:
|
|
15
15
|
# - better padding scheme (for now 2*shape)
|
16
16
|
# - rework inheritance scheme ? (base class SingleDistancePhaseRetrieval and its cuda counterpart)
|
17
17
|
class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
18
|
-
|
19
18
|
"""
|
20
19
|
Cuda back-end of CTFPhaseRetrieval
|
21
20
|
"""
|
@@ -37,7 +36,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
37
36
|
fft_num_threads=None,
|
38
37
|
logger=None,
|
39
38
|
cuda_options=None,
|
40
|
-
fft_backend="
|
39
|
+
fft_backend="vkfft",
|
41
40
|
):
|
42
41
|
"""
|
43
42
|
Initialize a CudaCTFPhaseRetrieval.
|
@@ -130,7 +129,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
130
129
|
self.set_input(image)
|
131
130
|
self.cuda_padding.pad(image, output=self.d_radio_padded)
|
132
131
|
if self.normalize_by_mean:
|
133
|
-
m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded)
|
132
|
+
m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded) # pylint: disable=E0606
|
134
133
|
self.d_radio_padded /= m
|
135
134
|
self.cufft.fft(self.d_radio_padded, output=self.d_radio_f)
|
136
135
|
self.cpxmult_kernel(*self._cpxmult_kernel_args, **self._cpxmult_kernel_kwargs)
|
nabu/preproc/double_flatfield.py
CHANGED
@@ -2,9 +2,9 @@ from os import path
|
|
2
2
|
import numpy as np
|
3
3
|
from scipy.ndimage import gaussian_filter
|
4
4
|
from silx.io.url import DataUrl
|
5
|
-
from ..utils import
|
6
|
-
from ..io.reader import
|
7
|
-
from ..io.writer import
|
5
|
+
from ..utils import check_shape, get_2D_3D_shape
|
6
|
+
from ..io.reader import HDF5Reader
|
7
|
+
from ..io.writer import NXProcessWriter
|
8
8
|
from .ccd import Log
|
9
9
|
|
10
10
|
|
@@ -64,12 +64,6 @@ class DoubleFlatField:
|
|
64
64
|
self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode)
|
65
65
|
self._computed = False
|
66
66
|
|
67
|
-
def _get_reader_writer_class(self):
|
68
|
-
ext = path.splitext(self.result_url.file_path())[-1].replace(".", "")
|
69
|
-
check_supported(ext, list(Writers.keys()), "file format")
|
70
|
-
self._writer_cls = Writers[ext]
|
71
|
-
self._reader_cls = Readers[ext]
|
72
|
-
|
73
67
|
def _load_dff_dump(self):
|
74
68
|
res = self.reader.get_data(self.result_url)
|
75
69
|
if self.detector_corrector is not None:
|
@@ -98,15 +92,14 @@ class DoubleFlatField:
|
|
98
92
|
self.reader = None
|
99
93
|
if self.result_url is None:
|
100
94
|
return
|
101
|
-
self._get_reader_writer_class()
|
102
95
|
if path.exists(result_url.file_path()):
|
103
96
|
if detector_corrector is None:
|
104
97
|
adapted_subregion = sub_region
|
105
98
|
else:
|
106
99
|
adapted_subregion = self.detector_corrector.get_adapted_subregion(sub_region)
|
107
|
-
self.reader =
|
100
|
+
self.reader = HDF5Reader(sub_region=adapted_subregion)
|
108
101
|
else:
|
109
|
-
self.writer =
|
102
|
+
self.writer = NXProcessWriter(self.result_url.file_path())
|
110
103
|
|
111
104
|
def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
|
112
105
|
self.input_is_mlog = input_is_mlog
|
@@ -59,7 +59,7 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
59
59
|
def _proc_expm(x, o):
|
60
60
|
o[:] = x[:]
|
61
61
|
o[:] *= -1
|
62
|
-
cumath.exp(o, out=o)
|
62
|
+
cumath.exp(o, out=o) # pylint: disable=E0606
|
63
63
|
return o
|
64
64
|
|
65
65
|
def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
|
@@ -108,7 +108,7 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
108
108
|
recompute: bool, optional
|
109
109
|
Whether to recompute the double flatfield if already computed.
|
110
110
|
"""
|
111
|
-
if not (isinstance(radios, garray.GPUArray)):
|
111
|
+
if not (isinstance(radios, garray.GPUArray)): # pylint: disable=E0606
|
112
112
|
raise ValueError("Expected pycuda.gpuarray.GPUArray for radios")
|
113
113
|
if self._computed and not (recompute):
|
114
114
|
return self.doubleflatfield
|
nabu/preproc/flatfield.py
CHANGED
@@ -2,7 +2,7 @@ from multiprocessing.pool import ThreadPool
|
|
2
2
|
from bisect import bisect_left
|
3
3
|
import numpy as np
|
4
4
|
from ..io.reader import load_images_from_dataurl_dict
|
5
|
-
from ..utils import check_supported, get_num_threads
|
5
|
+
from ..utils import check_supported, deprecated_class, get_num_threads
|
6
6
|
|
7
7
|
|
8
8
|
class FlatFieldArrays:
|
@@ -228,6 +228,7 @@ class FlatFieldArrays:
|
|
228
228
|
f_idx, weights = _interp_nearest(idx, prev_next)
|
229
229
|
elif self.interpolation == "linear":
|
230
230
|
f_idx, weights = _interp_linear(idx, prev_next)
|
231
|
+
# pylint: disable=E0606
|
231
232
|
self.flats_idx[i] = f_idx
|
232
233
|
self.flats_weights[i] = weights
|
233
234
|
|
@@ -376,6 +377,9 @@ class FlatFieldArrays:
|
|
376
377
|
FlatField = FlatFieldArrays
|
377
378
|
|
378
379
|
|
380
|
+
@deprecated_class(
|
381
|
+
"FlatFieldDataUrls is deprecated since 2024.2.0 and will be removed in a future version", do_print=True
|
382
|
+
)
|
379
383
|
class FlatFieldDataUrls(FlatField):
|
380
384
|
def __init__(
|
381
385
|
self,
|
nabu/preproc/flatfield_cuda.py
CHANGED
@@ -2,7 +2,7 @@ import numpy as np
|
|
2
2
|
|
3
3
|
from nabu.cuda.processing import CudaProcessing
|
4
4
|
from ..preproc.flatfield import FlatFieldArrays
|
5
|
-
from ..utils import get_cuda_srcfile
|
5
|
+
from ..utils import deprecated_class, get_cuda_srcfile
|
6
6
|
from ..io.reader import load_images_from_dataurl_dict
|
7
7
|
from ..cuda.utils import __has_pycuda__
|
8
8
|
|
@@ -114,6 +114,9 @@ class CudaFlatFieldArrays(FlatFieldArrays):
|
|
114
114
|
CudaFlatField = CudaFlatFieldArrays
|
115
115
|
|
116
116
|
|
117
|
+
@deprecated_class(
|
118
|
+
"CudaFlatFieldDataUrls is deprecated since version 2024.2.0 and will be removed in a future version", do_print=True
|
119
|
+
)
|
117
120
|
class CudaFlatFieldDataUrls(CudaFlatField):
|
118
121
|
def __init__(
|
119
122
|
self,
|
@@ -138,6 +141,7 @@ class CudaFlatFieldDataUrls(CudaFlatField):
|
|
138
141
|
radios_indices=radios_indices,
|
139
142
|
interpolation=interpolation,
|
140
143
|
distortion_correction=distortion_correction,
|
144
|
+
nan_value=nan_value,
|
141
145
|
radios_srcurrent=radios_srcurrent,
|
142
146
|
flats_srcurrent=flats_srcurrent,
|
143
147
|
cuda_options=cuda_options,
|
nabu/preproc/phase.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
|
|
4
4
|
from scipy.fft import rfft2, irfft2, fft2, ifft2
|
5
5
|
from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning
|
6
6
|
|
7
|
-
#
|
7
|
+
# COMPAT.
|
8
8
|
from .ctf import CTFPhaseRetrieval
|
9
9
|
|
10
10
|
#
|
@@ -50,7 +50,6 @@ class PaganinPhaseRetrieval:
|
|
50
50
|
delta_beta=250.0,
|
51
51
|
pixel_size=1e-6,
|
52
52
|
padding="edge",
|
53
|
-
margin=None,
|
54
53
|
use_rfft=True,
|
55
54
|
use_R2C=None,
|
56
55
|
fftw_num_threads=None,
|
@@ -73,42 +72,13 @@ class PaganinPhaseRetrieval:
|
|
73
72
|
delta_beta: float, optional
|
74
73
|
delta/beta ratio, where n = (1 - delta) + i*beta is the complex
|
75
74
|
refractive index of the sample.
|
76
|
-
pixel_size : float, optional
|
75
|
+
pixel_size : float or tuple, optional
|
77
76
|
Detector pixel size in meters. Default is 1e-6 (one micron)
|
77
|
+
If a tuple is passed, the pixel size is set as (horizontal_size, vertical_size).
|
78
78
|
padding : str, optional
|
79
79
|
Padding method. Available are "zeros", "mean", "edge", "sym",
|
80
80
|
"reflect". Default is "edge".
|
81
81
|
Please refer to the "Padding" section below for more details.
|
82
|
-
margin: tuple, optional
|
83
|
-
The user may provide integers values U, D, L, R as a tuple under the
|
84
|
-
form ((U, D), (L, R)) (same syntax as numpy.pad()).
|
85
|
-
The resulting filtered radio will have a size equal to
|
86
|
-
(size_vertic - U - D, size_horiz - L - R).
|
87
|
-
These values serve to create a "margin" for the filtering process,
|
88
|
-
where U, D, L R are the margin of the Up, Down, Left and Right part,
|
89
|
-
respectively.
|
90
|
-
The filtering is done on a subset of the input radio. The subset
|
91
|
-
size is (Nrows - U - D, Ncols - R - L).
|
92
|
-
The margins is used to do the padding for the rest of the padded
|
93
|
-
array.
|
94
|
-
|
95
|
-
For example in one dimension, where ``padding="edge"``::
|
96
|
-
|
97
|
-
<------------------------------ padded_size --------------------------->
|
98
|
-
[padding=edge | padding=data | radio data | padding=data | padding=edge]
|
99
|
-
<------ N2 ---><----- L -----><- (N-L-R)--><----- R -----><----- N2 --->
|
100
|
-
|
101
|
-
Some or all the values U, D, L, R can be 0. In this case,
|
102
|
-
the padding of the parts related to the zero values will
|
103
|
-
fall back to the one of "padding" parameter.
|
104
|
-
For example, if padding="edge" and L, R are 0, then
|
105
|
-
the left and right parts will be padded with the edges, while
|
106
|
-
the Up and Down parts will be padded using the the user-provided
|
107
|
-
margins of the radio, and the final data will have shape
|
108
|
-
(Nrows - U - D, Ncols).
|
109
|
-
Some or all the values U, D, L, R can be the string "auto".
|
110
|
-
In this case, the values of U, D, L, R are automatically computed
|
111
|
-
as a function of the Paganin filter width.
|
112
82
|
use_rfft: bool, optional
|
113
83
|
Whether to use Real-to-Complex (R2C) transform instead of
|
114
84
|
standard Complex-to-Complex transform, providing better performances
|
@@ -171,7 +141,7 @@ class PaganinPhaseRetrieval:
|
|
171
141
|
Journal of Microscopy, Vol 206, Part 1, 2002
|
172
142
|
"""
|
173
143
|
self._init_parameters(distance, energy, pixel_size, delta_beta, padding)
|
174
|
-
self._calc_shape(shape
|
144
|
+
self._calc_shape(shape)
|
175
145
|
# COMPAT.
|
176
146
|
if use_R2C is not None:
|
177
147
|
deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c")
|
@@ -186,7 +156,13 @@ class PaganinPhaseRetrieval:
|
|
186
156
|
self.distance_cm = distance * 1e2
|
187
157
|
self.distance_micron = distance * 1e6
|
188
158
|
self.energy_kev = energy
|
189
|
-
|
159
|
+
if np.isscalar(pixel_size):
|
160
|
+
self.pixel_size_xy_micron = (pixel_size * 1e6, pixel_size * 1e6)
|
161
|
+
else:
|
162
|
+
self.pixel_size_xy_micron = pixel_size * 1e6
|
163
|
+
# COMPAT.
|
164
|
+
self.pixel_size_micron = self.pixel_size_xy_micron[0]
|
165
|
+
#
|
190
166
|
self.delta_beta = delta_beta
|
191
167
|
self.wavelength_micron = 1.23984199e-3 / self.energy_kev
|
192
168
|
self.padding = padding
|
@@ -209,34 +185,14 @@ class PaganinPhaseRetrieval:
|
|
209
185
|
self.fft_func = fft2
|
210
186
|
self.ifft_func = ifft2
|
211
187
|
|
212
|
-
def _calc_shape(self, shape
|
188
|
+
def _calc_shape(self, shape):
|
213
189
|
if np.isscalar(shape):
|
214
190
|
shape = (shape, shape)
|
215
191
|
else:
|
216
192
|
assert len(shape) == 2
|
217
193
|
self.shape = shape
|
218
|
-
self._set_margin_value(margin)
|
219
194
|
self._calc_padded_shape()
|
220
195
|
|
221
|
-
def _set_margin_value(self, margin):
|
222
|
-
self.margin = margin
|
223
|
-
if margin is None:
|
224
|
-
self.shape_inner = self.shape
|
225
|
-
self.use_margin = False
|
226
|
-
self.margin = ((0, 0), (0, 0))
|
227
|
-
return
|
228
|
-
self.use_margin = True
|
229
|
-
try:
|
230
|
-
((U, D), (L, R)) = margin
|
231
|
-
except ValueError:
|
232
|
-
raise ValueError("Expected margin in the format ((U, D), (L, R))")
|
233
|
-
for val in [U, D, L, R]:
|
234
|
-
if isinstance(val, str) and val != "auto":
|
235
|
-
raise ValueError("Expected either an integer, or 'auto'")
|
236
|
-
if int(val) != val or val < 0:
|
237
|
-
raise ValueError("Expected positive integers for margin values")
|
238
|
-
self.shape_inner = (self.shape[0] - U - D, self.shape[1] - L - R)
|
239
|
-
|
240
196
|
def _calc_padded_shape(self):
|
241
197
|
"""
|
242
198
|
Compute the padded shape.
|
@@ -257,19 +213,15 @@ class PaganinPhaseRetrieval:
|
|
257
213
|
nx0 : length of original data
|
258
214
|
nx_p : total length of padded data
|
259
215
|
"""
|
260
|
-
n_y, n_x = self.
|
261
|
-
|
262
|
-
|
263
|
-
n_x_p = self._get_next_power(max(2 * n_x, n_x0))
|
216
|
+
n_y, n_x = self.shape
|
217
|
+
n_y_p = self._get_next_power(2 * n_y)
|
218
|
+
n_x_p = self._get_next_power(2 * n_x)
|
264
219
|
self.shape_padded = (n_y_p, n_x_p)
|
265
220
|
self.data_padded = np.zeros((n_y_p, n_x_p), dtype=np.float64)
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
self.
|
270
|
-
self.pad_bottom_len = n_y_p - n_y0 - self.pad_top_len
|
271
|
-
self.pad_left_len = (n_x_p - n_x0) // 2
|
272
|
-
self.pad_right_len = n_x_p - n_x0 - self.pad_left_len
|
221
|
+
self.pad_top_len = (n_y_p - n_y) // 2
|
222
|
+
self.pad_bottom_len = n_y_p - n_y - self.pad_top_len
|
223
|
+
self.pad_left_len = (n_x_p - n_x) // 2
|
224
|
+
self.pad_right_len = n_x_p - n_x - self.pad_left_len
|
273
225
|
|
274
226
|
def _get_next_power(self, n):
|
275
227
|
"""
|
@@ -284,8 +236,8 @@ class PaganinPhaseRetrieval:
|
|
284
236
|
def compute_filter(self):
|
285
237
|
nyp, nxp = self.shape_padded
|
286
238
|
fftfreq = np.fft.rfftfreq if self.use_rfft else np.fft.fftfreq
|
287
|
-
fy = np.fft.fftfreq(nyp, d=self.
|
288
|
-
fx = fftfreq(nxp, d=self.
|
239
|
+
fy = np.fft.fftfreq(nyp, d=self.pixel_size_xy_micron[1])
|
240
|
+
fx = fftfreq(nxp, d=self.pixel_size_xy_micron[0])
|
289
241
|
self._coords_grid = np.add.outer(fy**2, fx**2)
|
290
242
|
#
|
291
243
|
k2 = self._coords_grid
|
@@ -376,12 +328,11 @@ class PaganinPhaseRetrieval:
|
|
376
328
|
radio_f = self.fft_func(self.data_padded, workers=self.fft_num_threads)
|
377
329
|
radio_f *= self.paganin_filter
|
378
330
|
radio_filtered = self.ifft_func(radio_f, workers=self.fft_num_threads).real
|
379
|
-
s0, s1 = self.
|
380
|
-
((U, _), (L, _)) = self.margin
|
331
|
+
s0, s1 = self.shape
|
381
332
|
if output is None:
|
382
|
-
return radio_filtered[
|
333
|
+
return radio_filtered[:s0, :s1]
|
383
334
|
else:
|
384
|
-
output[:, :] = radio_filtered[
|
335
|
+
output[:, :] = radio_filtered[:s0, :s1]
|
385
336
|
return output
|
386
337
|
|
387
338
|
def lmicron_to_db(self, Lmicron):
|
nabu/preproc/phase_cuda.py
CHANGED
@@ -18,11 +18,10 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
18
18
|
delta_beta=250.0,
|
19
19
|
pixel_size=1e-6,
|
20
20
|
padding="edge",
|
21
|
-
margin=None,
|
22
21
|
cuda_options=None,
|
23
22
|
fftw_num_threads=None, # COMPAT.
|
24
23
|
fft_num_threads=None,
|
25
|
-
fft_backend="
|
24
|
+
fft_backend="vkfft",
|
26
25
|
):
|
27
26
|
"""
|
28
27
|
Please refer to the documentation of
|
@@ -37,7 +36,6 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
37
36
|
delta_beta=delta_beta,
|
38
37
|
pixel_size=pixel_size,
|
39
38
|
padding=padding,
|
40
|
-
margin=margin,
|
41
39
|
use_rfft=True,
|
42
40
|
fft_num_threads=False,
|
43
41
|
)
|
@@ -118,14 +116,13 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
118
116
|
raise ValueError("Expected either numpy array, pycuda array or pycuda buffer")
|
119
117
|
|
120
118
|
def get_output(self, output):
|
121
|
-
s0, s1 = self.
|
122
|
-
((U, _), (L, _)) = self.margin
|
119
|
+
s0, s1 = self.shape
|
123
120
|
if output is None:
|
124
121
|
# copy D2H
|
125
|
-
return self.d_radio_padded[
|
126
|
-
assert output.shape == self.
|
122
|
+
return self.d_radio_padded[:s0, :s1].get()
|
123
|
+
assert output.shape == self.shape
|
127
124
|
assert output.dtype == np.float32
|
128
|
-
output[:, :] = self.d_radio_padded[
|
125
|
+
output[:, :] = self.d_radio_padded[:s0, :s1]
|
129
126
|
return output
|
130
127
|
|
131
128
|
def apply_filter(self, radio, output=None):
|
nabu/preproc/tests/test_ctf.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import pytest
|
2
2
|
import numpy as np
|
3
3
|
import scipy.interpolate
|
4
|
+
from nabu.processing.fft_cuda import get_available_fft_implems
|
4
5
|
from nabu.testutils import get_data as nabu_get_data
|
5
6
|
from nabu.testutils import __do_long_tests__
|
6
7
|
from nabu.preproc.flatfield import FlatFieldArrays
|
@@ -9,11 +10,14 @@ from nabu.preproc import ctf
|
|
9
10
|
from nabu.estimation.distortion import estimate_flat_distortion
|
10
11
|
from nabu.misc.filters import correct_spikes
|
11
12
|
from nabu.preproc.distortion import DistortionCorrection
|
12
|
-
from nabu.cuda.utils import __has_pycuda__,
|
13
|
+
from nabu.cuda.utils import __has_pycuda__, get_cuda_context
|
13
14
|
|
14
|
-
|
15
|
+
__has_cufft__ = False
|
16
|
+
if __has_pycuda__:
|
15
17
|
from nabu.preproc.ctf_cuda import CudaCTFPhaseRetrieval
|
16
|
-
|
18
|
+
|
19
|
+
avail_fft = get_available_fft_implems()
|
20
|
+
__has_cufft__ = len(avail_fft) > 0
|
17
21
|
|
18
22
|
|
19
23
|
@pytest.fixture(scope="class")
|
@@ -39,7 +43,7 @@ def bootstrap_TestCtf(request):
|
|
39
43
|
cls.padded_img_shape_vh = test_data["padded_img_shape_vh"]
|
40
44
|
cls.z1_vh = test_data["z1_vh"]
|
41
45
|
cls.z2 = test_data["z2"]
|
42
|
-
cls.pix_size_det = test_data["pix_size_det"]
|
46
|
+
cls.pix_size_det = test_data["pix_size_det"][()]
|
43
47
|
cls.length_scale = test_data["length_scale"]
|
44
48
|
cls.wavelength = test_data["wave_length"]
|
45
49
|
cls.remove_spikes_threshold = test_data["remove_spikes_threshold"]
|
@@ -174,7 +178,7 @@ class TestCtf:
|
|
174
178
|
phase = ctf_filter.retrieve_phase(img)
|
175
179
|
|
176
180
|
message = "retrieved phase and reference result differ beyond the accepted tolerance"
|
177
|
-
assert np.abs(phase - self.expected_result).max() < self.abs_tol * (
|
181
|
+
assert np.abs(phase - self.expected_result).max() < 10 * self.abs_tol * (
|
178
182
|
np.abs(self.expected_result).mean()
|
179
183
|
), message
|
180
184
|
|
@@ -219,7 +223,7 @@ class TestCtf:
|
|
219
223
|
phase_fft = ctf_fft.retrieve_phase(img)
|
220
224
|
self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFT")
|
221
225
|
|
222
|
-
@pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and scikit-cuda")
|
226
|
+
@pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (scikit-cuda or vkfft)")
|
223
227
|
def test_cuda_ctf(self):
|
224
228
|
data = nabu_get_data("brain_phantom.npz")["data"]
|
225
229
|
delta_beta = 50.0
|
@@ -243,7 +247,7 @@ class TestCtf:
|
|
243
247
|
)
|
244
248
|
ref = ctf_filter.retrieve_phase(data)
|
245
249
|
|
246
|
-
d_data =
|
250
|
+
d_data = cuda_ctf_filter.cuda_processing.to_device("_d_data", data)
|
247
251
|
res = cuda_ctf_filter.retrieve_phase(d_data).get()
|
248
252
|
err_max = np.max(np.abs(res - ref))
|
249
253
|
|