nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/cast_volume.py +9 -1
- nabu/app/cli_configs.py +80 -3
- nabu/app/estimate_motion.py +54 -0
- nabu/app/multicor.py +2 -4
- nabu/app/pcaflats.py +116 -0
- nabu/app/reconstruct.py +1 -7
- nabu/app/reduce_dark_flat.py +5 -2
- nabu/estimation/cor.py +1 -1
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tilt.py +1 -1
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +100 -13
- nabu/io/reader.py +32 -1
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/pipeline/config_validators.py +42 -43
- nabu/pipeline/estimators.py +255 -0
- nabu/pipeline/fullfield/chunked.py +67 -43
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/nabu_config.py +20 -14
- nabu/pipeline/fullfield/processconfig.py +17 -3
- nabu/pipeline/fullfield/reconstruction.py +4 -1
- nabu/pipeline/params.py +12 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/flatfield.py +306 -1
- nabu/preproc/shift.py +3 -1
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +31 -7
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering_opencl.py +2 -0
- nabu/reconstruction/mlem.py +47 -13
- nabu/reconstruction/tests/test_filtering.py +13 -2
- nabu/reconstruction/tests/test_mlem.py +91 -62
- nabu/resources/dataset_analyzer.py +144 -20
- nabu/resources/nxflatfield.py +101 -35
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +7 -7
- nabu/stitching/config.py +22 -20
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/stitcher/post_processing.py +5 -3
- nabu/stitching/stitcher/pre_processing.py +24 -20
- nabu/stitching/tests/test_config.py +3 -3
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
- nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
- nabu/stitching/utils/utils.py +7 -7
- nabu/testutils.py +1 -4
- nabu/utils.py +13 -0
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/METADATA +3 -4
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/RECORD +64 -57
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/WHEEL +1 -1
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -62
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/licenses/LICENSE +0 -0
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,23 @@
|
|
1
1
|
import os
|
2
|
+
from tempfile import TemporaryDirectory
|
2
3
|
import pytest
|
3
4
|
import numpy as np
|
4
|
-
from
|
5
|
-
from
|
5
|
+
from pint import get_application_registry
|
6
|
+
from nxtomo import NXtomo
|
7
|
+
from nabu.testutils import utilstest, __do_long_tests__, get_data
|
8
|
+
from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset, ImageKey
|
6
9
|
from nabu.resources.nxflatfield import update_dataset_info_flats_darks
|
7
10
|
from nabu.resources.utils import extract_parameters
|
8
|
-
from nabu.pipeline.estimators import CompositeCOREstimator
|
11
|
+
from nabu.pipeline.estimators import CompositeCOREstimator, TranslationsEstimator
|
9
12
|
from nabu.pipeline.config import parse_nabu_config_file
|
10
13
|
from nabu.pipeline.estimators import SinoCORFinder, CORFinder
|
11
14
|
|
15
|
+
from nabu.estimation.tests.test_motion_estimation import (
|
16
|
+
check_motion_estimation,
|
17
|
+
project_volume,
|
18
|
+
_create_translations_vector,
|
19
|
+
)
|
20
|
+
|
12
21
|
|
13
22
|
#
|
14
23
|
# Test CoR estimation with "composite-coarse-to-fine" (aka "near" in the legacy system vocable)
|
@@ -119,3 +128,231 @@ class TestCorNearPos:
|
|
119
128
|
cor = finder.find_cor()
|
120
129
|
message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}."
|
121
130
|
assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message
|
131
|
+
|
132
|
+
|
133
|
+
def _add_fake_flats_and_dark_to_data(data, n_darks=10, n_flats=21, dark_val=1, flat_val=3):
|
134
|
+
img_shape = data.shape[1:]
|
135
|
+
# Use constant darks/flats, to avoid "reduction" (mean/median) issues
|
136
|
+
fake_darks = np.ones((n_darks,) + img_shape, dtype=np.uint16) * dark_val
|
137
|
+
fake_flats = np.ones((n_flats,) + img_shape, dtype=np.uint16) * flat_val
|
138
|
+
return data * (fake_flats[0, 0, 0] - fake_darks[0, 0, 0]) + fake_darks[0, 0, 0], fake_darks, fake_flats
|
139
|
+
|
140
|
+
|
141
|
+
def _generate_nx_for_180_dataset(volume, output_file_path, n_darks=10, n_flats=21):
    """
    Simulate a 180-degree scan with sample motion and save it as an NXtomo file.

    The volume is projected over 250 angles in [0, pi[, followed by 5 "return"
    projections (180, 135, 90, 45, 0 degrees) which are stored as alignment images.
    Fake constant darks/flats are written before the projections.

    Returns
    -------
    sample_motion_xy: numpy.ndarray
        Ground-truth (x, y) sample motion, one row per projection (return projections included).
    sample_motion_z: numpy.ndarray
        Ground-truth vertical sample motion.
    cor: float
        Center of rotation used for the simulation.
    """
    n_angles = 250
    cor = -10
    # Parameters of the simulated motion (see _create_translations_vector)
    alpha_x, beta_x = 4, 3
    alpha_y, beta_y = -5, 10
    beta_z = 0
    orig_det_dist = 0

    scan_angles = np.linspace(0, np.pi, n_angles, False)
    return_angles = np.deg2rad([180.0, 135.0, 90.0, 45.0, 0.0])
    all_angles = np.concatenate([scan_angles, return_angles])
    # Normalized "time": from 0 to 1 over the main scan, beyond 1 for the return projections
    progress = np.arange(all_angles.size) / scan_angles.size

    tx, ty, tz = (
        _create_translations_vector(progress, alpha, beta)
        for alpha, beta in ((alpha_x, beta_x), (alpha_y, beta_y), (0, beta_z))
    )

    sinos = project_volume(volume, all_angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
    projections = np.moveaxis(sinos, 1, 0)

    sample_motion_xy = np.stack([-tx, ty], axis=1)
    sample_motion_z = -tz
    scan_angles_deg = np.degrees(scan_angles)
    return_angles_deg = np.degrees(return_angles)
    n_return = len(return_angles_deg)
    n_scan = projections.shape[0] - n_return

    ureg = get_application_registry()
    raw_frames, darks, flats = _add_fake_flats_and_dark_to_data(projections, n_darks=n_darks, n_flats=n_flats)

    nxtomo = NXtomo()
    # darks, flats, then radios + return radios (in float32 !)
    nxtomo.instrument.detector.data = np.concatenate([darks, flats, raw_frames])
    nxtomo.instrument.detector.image_key_control = np.concatenate(
        [
            [ImageKey.DARK_FIELD.value] * n_darks,
            [ImageKey.FLAT_FIELD.value] * n_flats,
            [ImageKey.PROJECTION.value] * n_scan,
            [ImageKey.ALIGNMENT.value] * n_return,
        ]
    )
    angles_column = np.concatenate(
        [np.zeros(n_darks, dtype="f"), np.zeros(n_flats, dtype="f"), scan_angles_deg, return_angles_deg]
    )
    nxtomo.sample.rotation_angle = angles_column * ureg.degree
    nxtomo.instrument.detector.field_of_view = "Full"
    nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
    nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

    return sample_motion_xy, sample_motion_z, cor
|
202
|
+
|
203
|
+
|
204
|
+
def _generate_nx_for_360_dataset(volume, output_file_path, n_darks=10, n_flats=21):
    """
    Simulate a 360-degree scan with sample motion and save it as an NXtomo file.

    The volume is projected over 250 angles in [0, 2*pi[.
    Fake constant darks/flats are written before the projections.

    Returns
    -------
    sample_motion_xy: numpy.ndarray
        Ground-truth (x, y) sample motion, one row per projection.
    sample_motion_z: numpy.ndarray
        Ground-truth vertical sample motion.
    cor: float
        Center of rotation used for the simulation.
    """
    n_angles = 250
    cor = -5.5
    # Parameters of the simulated motion (see _create_translations_vector)
    alpha_x, beta_x = -2, 7.0
    alpha_y, beta_y = -2, 3
    beta_z = 100
    orig_det_dist = 0

    rot_angles = np.linspace(0, 2 * np.pi, n_angles, False)
    progress = np.linspace(0, 1, rot_angles.size, endpoint=False)  # theta/theta_max

    tx, ty, tz = (
        _create_translations_vector(progress, alpha, beta)
        for alpha, beta in ((alpha_x, beta_x), (alpha_y, beta_y), (0, beta_z))
    )

    sinos = project_volume(volume, rot_angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
    projections = np.moveaxis(sinos, 1, 0)

    sample_motion_xy = np.stack([-tx, ty], axis=1)
    sample_motion_z = -tz
    rot_angles_deg = np.degrees(rot_angles)

    ureg = get_application_registry()

    raw_frames, darks, flats = _add_fake_flats_and_dark_to_data(projections, n_darks=n_darks, n_flats=n_flats)

    nxtomo = NXtomo()
    nxtomo.instrument.detector.data = np.concatenate([darks, flats, raw_frames])  # in float32 !
    nxtomo.instrument.detector.image_key_control = np.concatenate(
        [
            [ImageKey.DARK_FIELD.value] * n_darks,
            [ImageKey.FLAT_FIELD.value] * n_flats,
            [ImageKey.PROJECTION.value] * projections.shape[0],
        ]
    )
    angles_column = np.concatenate([np.zeros(n_darks, dtype="f"), np.zeros(n_flats, dtype="f"), rot_angles_deg])
    nxtomo.sample.rotation_angle = angles_column * ureg.degree
    nxtomo.instrument.detector.field_of_view = "Full"
    nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
    nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

    return sample_motion_xy, sample_motion_z, cor
|
259
|
+
|
260
|
+
|
261
|
+
@pytest.fixture(scope="class")
def setup_test_motion_estimator(request):
    """Load the subsampled MRI test volume once, as the ``volume`` attribute of the test class."""
    setattr(request.cls, "volume", get_data("motion/mri_volume_subsampled.npy"))
|
265
|
+
|
266
|
+
|
267
|
+
@pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1")
@pytest.mark.usefixtures("setup_test_motion_estimator")
class TestMotionEstimator:
    """Motion estimation (TranslationsEstimator) on simulated datasets with known sample motion."""

    def _setup(self, tmpdir):
        """
        Load the test volume if needed, and return ``tmpdir`` as a plain string.

        Through pytest, ``tmpdir`` is pytest's legacy path object and the volume is
        loaded by the class fixture. When the tests are run as a script, ``tmpdir``
        is already a string and the volume is loaded here.
        """
        if getattr(self, "volume", None) is None:
            self.volume = get_data("motion/mri_volume_subsampled.npy")
        return str(tmpdir)

    def test_estimate_motion_360_dataset(self, tmpdir, verbose=False):
        tmpdir = self._setup(tmpdir)
        nx_file_path = os.path.join(tmpdir, "mri_projected_360_motion.nx")
        sample_motion_xy, sample_motion_z, cor = _generate_nx_for_360_dataset(self.volume, nx_file_path)

        dataset_info = analyze_dataset(nx_file_path)

        translations_estimator = TranslationsEstimator(
            dataset_info, do_flatfield=True, rot_center=cor, angular_subsampling=5, deg_xy=2, deg_z=2
        )
        # Results are checked through the "motion_estimator" attribute below
        _shifts_h, _shifts_v, _estimated_cor = translations_estimator.estimate_motion()

        # The estimation is done on a subset of the angles: compare against the same subset
        s = translations_estimator.angular_subsampling
        angles = dataset_info.rotation_angles[::s]
        gt_xy = sample_motion_xy[::s, :]
        gt_z = sample_motion_z[::s]
        if verbose:
            translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
            translations_estimator.motion_estimator.plot_movements(
                cor=cor,
                angles_rad=angles,
                gt_xy=gt_xy,
                gt_z=gt_z,
            )
        check_motion_estimation(
            translations_estimator.motion_estimator,
            angles,
            cor,
            gt_xy,
            gt_z,
            fit_error_shifts_tol_vu=(0.2, 0.2),
            fit_error_det_tol_vu=(1e-5, 5e-2),
            fit_error_tol_xyz=(0.05, 0.05, 0.05),
            fit_error_det_all_angles_tol_vu=(1e-5, 0.05),
        )

    def test_estimate_motion_180_dataset(self, tmpdir, verbose=False):
        tmpdir = self._setup(tmpdir)
        nx_file_path = os.path.join(tmpdir, "mri_projected_180_motion.nx")

        sample_motion_xy, sample_motion_z, cor = _generate_nx_for_180_dataset(self.volume, nx_file_path)

        dataset_info = analyze_dataset(nx_file_path)

        translations_estimator = TranslationsEstimator(
            dataset_info,
            do_flatfield=True,
            rot_center=cor,
            angular_subsampling=2,
            deg_xy=2,
            deg_z=2,
            shifts_estimator="DetectorTranslationAlongBeam",
        )
        # Results are checked through the "motion_estimator" attribute below
        _shifts_h, _shifts_v, _estimated_cor = translations_estimator.estimate_motion()

        if verbose:
            translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
            translations_estimator.motion_estimator.plot_movements(
                cor=cor,
                angles_rad=dataset_info.rotation_angles,
                gt_xy=sample_motion_xy[: dataset_info.n_angles],
                gt_z=sample_motion_z[: dataset_info.n_angles],
            )

        check_motion_estimation(
            translations_estimator.motion_estimator,
            dataset_info.rotation_angles,
            cor,
            sample_motion_xy,
            sample_motion_z,
            fit_error_shifts_tol_vu=(0.02, 0.1),
            fit_error_det_tol_vu=(1e-2, 0.5),
            fit_error_tol_xyz=(0.5, 2, 1e-2),
            fit_error_det_all_angles_tol_vu=(1e-2, 2),
        )
|
351
|
+
|
352
|
+
|
353
|
+
if __name__ == "__main__":
    # Run both tests outside of pytest, with the diagnostic plots enabled
    tester = TestMotionEstimator()
    with TemporaryDirectory(suffix="_motion", prefix="nabu_testdata") as tmp:
        for run_test in (tester.test_estimate_motion_360_dataset, tester.test_estimate_motion_180_dataset):
            run_test(tmp, verbose=True)
|
nabu/preproc/ccd.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import numpy as np
|
2
2
|
from ..utils import check_supported
|
3
|
+
from scipy.ndimage import binary_dilation
|
3
4
|
from silx.math.medianfilter import medfilt2d
|
4
5
|
|
5
6
|
|
@@ -13,6 +14,7 @@ class CCDFilter:
|
|
13
14
|
def __init__(
|
14
15
|
self,
|
15
16
|
radios_shape: tuple,
|
17
|
+
kernel_size: int = 3,
|
16
18
|
correction_type: str = "median_clip",
|
17
19
|
median_clip_thresh: float = 0.1,
|
18
20
|
abs_diff=False,
|
@@ -26,6 +28,9 @@ class CCDFilter:
|
|
26
28
|
radios_shape: tuple
|
27
29
|
A tuple describing the shape of the radios stack, in the form
|
28
30
|
`(n_radios, n_z, n_x)`.
|
31
|
+
kernel_size: int
|
32
|
+
Size of the kernel for the median filter.
|
33
|
+
Default is 3.
|
29
34
|
correction_type: str
|
30
35
|
Correction type for radios ("median_clip", "sigma_clip", ...)
|
31
36
|
median_clip_thresh: float, optional
|
@@ -48,6 +53,7 @@ class CCDFilter:
|
|
48
53
|
then this pixel value is set to the median value.
|
49
54
|
"""
|
50
55
|
self._set_radios_shape(radios_shape)
|
56
|
+
self.kernel_size = kernel_size
|
51
57
|
check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode")
|
52
58
|
self.correction_type = correction_type
|
53
59
|
self.median_clip_thresh = median_clip_thresh
|
@@ -67,11 +73,11 @@ class CCDFilter:
|
|
67
73
|
self.shape = (n_z, n_x)
|
68
74
|
|
69
75
|
@staticmethod
def median_filter(img, kernel_size=3):
    """
    Median-filter a 2D image with a square window of side ``kernel_size``.

    Borders are handled with the "reflect" mode of silx ``medfilt2d``.
    """
    window = (kernel_size,) * 2
    return medfilt2d(img, window, mode="reflect")
|
75
81
|
|
76
82
|
def median_clip_mask(self, img, return_medians=False):
|
77
83
|
"""
|
@@ -85,7 +91,7 @@ class CCDFilter:
|
|
85
91
|
return_medians: bool, optional
|
86
92
|
Whether to return the median values additionally to the mask.
|
87
93
|
"""
|
88
|
-
median_values = self.median_filter(img)
|
94
|
+
median_values = self.median_filter(img, kernel_size=self.kernel_size)
|
89
95
|
if not self.abs_diff:
|
90
96
|
invalid_mask = img >= median_values + self.median_clip_thresh
|
91
97
|
else:
|
@@ -124,6 +130,50 @@ class CCDFilter:
|
|
124
130
|
|
125
131
|
return output
|
126
132
|
|
133
|
+
def dezinger_correction(self, radios, dark=None, nsigma=5, output=None):
    """
    Remove zingers (isolated bright outliers) from a radios stack.

    For each radio, the (optionally dark-subtracted) image is compared to its
    median-filtered version. Pixels exceeding the median by more than ``nsigma``
    times the standard deviation of the residual are flagged. The mask is then
    propagated into the vertical and horizontal directions (binary dilation),
    and flagged pixels are replaced by the median value.

    Parameters
    ----------
    radios: numpy.ndarray
        A radios stack, in the form (n_radios, n_z, n_x).
    dark: numpy.ndarray, optional
        A dark image, in the form (n_z, n_x). Default is None.
        If not None, it is subtracted from each radio for the zinger detection.
        The returned radios are NOT dark-subtracted: replaced pixels are set to
        ``median + dark``, so that the output stays in the same frame as the input.
    nsigma: float, optional
        Number of standard deviations to use for the zinger detection.
        Default is 5.
    output: numpy.ndarray, optional
        Output array. If None, a new array is allocated.
        Passing ``radios`` itself performs the correction in-place.

    Returns
    -------
    output: numpy.ndarray
        The corrected radios stack.

    Raises
    ------
    ValueError
        If the radios shape or the dark shape does not match the expected shape.
    """
    if radios.shape[1:] != self.radios_shape[1:]:
        raise ValueError(f"Expected radios shape {self.radios_shape}, got {radios.shape}")
    # The dark is the same for all radios: validate it once
    if dark is not None and dark.shape != radios.shape[1:]:
        raise ValueError("Dark image shape does not match radios shape.")

    if output is None:
        output = np.copy(radios)
    else:
        output[:] = radios[:]

    # Detection is done in floating point: with integer data (e.g. uint16),
    # "img - median" would wrap around and flag every pixel below its median.
    work_dtype = np.result_type(radios.dtype, np.float32)

    for i in range(radios.shape[0]):
        dimg = radios[i].astype(work_dtype)
        if dark is not None:
            dimg = dimg - dark
        med = self.median_filter(dimg, self.kernel_size)
        err = dimg - med
        msk = err > err.std() * nsigma
        # Default structuring element is a cross: grow the mask into vertical/horizontal neighbours
        gromsk = binary_dilation(msk)
        # Replacement values must be in the same frame as radios[i], i.e. with the dark added back
        replacement = med if dark is None else med + dark
        output[i] = np.where(gromsk, replacement, radios[i])

    return output
|
176
|
+
|
127
177
|
|
128
178
|
class Log:
|
129
179
|
"""
|
nabu/preproc/flatfield.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1
|
+
import os
|
1
2
|
from multiprocessing.pool import ThreadPool
|
2
3
|
from bisect import bisect_left
|
3
4
|
import numpy as np
|
5
|
+
from tomoscan.io import HDF5File
|
4
6
|
from ..io.reader import load_images_from_dataurl_dict
|
5
7
|
from ..utils import check_supported, deprecated_class, get_num_threads
|
8
|
+
from .ccd import CCDFilter
|
6
9
|
|
7
10
|
|
8
11
|
class FlatFieldArrays:
|
@@ -102,7 +105,7 @@ class FlatFieldArrays:
|
|
102
105
|
self._precompute_flats_indices_weights()
|
103
106
|
self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
|
104
107
|
self.distortion_correction = distortion_correction
|
105
|
-
self.n_threads =
|
108
|
+
self.n_threads = max(1, get_num_threads(n_threads) // 2)
|
106
109
|
|
107
110
|
def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
|
108
111
|
self._set_radios_shape(radios_shape)
|
@@ -450,3 +453,305 @@ class FlatFieldDataUrls(FlatField):
|
|
450
453
|
radios_srcurrent=radios_srcurrent,
|
451
454
|
flats_srcurrent=flats_srcurrent,
|
452
455
|
)
|
456
|
+
|
457
|
+
|
458
|
+
class PCAFlatsNormalizer:
    """This class implements a flatfield normalization based on a PCA of a series of acquired flatfields.
    The PCA decomposition is handled by a PCAFlatsDecomposer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.

    """

    def __init__(self, components, dark, mean):
        """Initializes all variables needed to perform the flatfield normalization.

        Parameters
        ----------
        components: ndarray
            The components of the PCA decomposition.
        dark: ndarray
            The dark image. Should be one single 2D image.
        mean: ndarray
            The mean image of the series of flats.
        """
        # An additional constant component accounts for I0
        ones = np.ones_like(components[0], dtype=np.float32)[np.newaxis]
        self.components = np.concatenate([ones, components], axis=0)
        self.dark = dark
        self.mean = mean
        self.n_threads = max(1, get_num_threads() // 2)
        self._setmask()
        self.ccdfilter = CCDFilter(mean.shape)

    def _form_lsq_matrix(self):
        """Form the least-squares matrix from the components restricted to the current mask.

        Sets ``self.g`` (list of masked components) and ``self.Amat``
        (one column per component, one row per masked pixel).
        """
        self.g = [component[self.mask] for component in self.components]
        self.Amat = np.stack(self.g, axis=1)

    def _setmask(self, prop=0.125):
        """Sets the mask to select where the model is going to be fitted.

        Parameters
        ----------
        prop: float, default: 0.125
            The proportion of the image width to take on each side of the image as a mask.

        By default it sets the strips on each side of the frame in the form:
            mask[:, :lim] = True
            mask[:, -lim:] = True
        Where lim = int(prop * flat.shape[1])

        If you need a custom mask, see update_mask() method.

        Raises
        ------
        ValueError
            If ``prop`` is too small to select at least one column on each side.
        """
        n_columns = self.mean.shape[1]
        lim = int(prop * n_columns)
        if lim < 1:
            # Note: "mask[:, -0:]" would select the whole image instead of an empty strip
            raise ValueError(
                f"prop={prop} is too small for an image width of {n_columns} pixels: the mask would be empty"
            )
        self.mask = np.zeros(self.mean.shape, dtype=bool)
        self.mask[:, :lim] = True
        self.mask[:, -lim:] = True
        self._form_lsq_matrix()

    def update_mask(self, mask: np.ndarray):
        """Update the mask with a custom mask in the form of a boolean 2D array.

        Parameters
        ----------
        mask: np.ndarray of bool
            Selects the region of the image that is used to fit against the components.
            It replaces the standard mask created with _setmask().

        Raises
        ------
        TypeError
            If ``mask`` is not a boolean array (the current mask is then kept).
        """
        if mask.dtype != bool:
            raise TypeError("Not a boolean array. Will keep the default mask")
        self.mask = mask
        self._form_lsq_matrix()

    def normalize_radios(self, projections, mask=None, prop=0.125):
        """Normalize ``projections`` in-place, to keep the flatfield API in the pipeline.

        Returns the (normalized) ``projections`` array.
        """
        return self.correct_stack(projections, mask=mask, prop=prop)

    def correct_stack(
        self,
        projections: np.ndarray,
        mask: np.ndarray = None,
        prop: float = 0.125,
    ):
        """Normalize a stack of projections.

        Performs correction on a stack of projections based on the calculated decomposition.
        The normalization is done in-place. The projections before normalization are lost.

        Parameters
        ----------
        projections: ndarray
            Stack of projections to normalize.

        prop: float (default: 0.125)
            Fraction to mask on the horizontal field of view, assuming vertical rotation axis.
            Only used when ``mask`` is None.

        mask: np.ndarray (default: None)
            Custom mask if your data requires it.

        Returns
        -------
        corrected projections: np.ndarray
            The input array, holding the flat field corrected images.
            Note that the projections are exp-transformed to fit the pipeline (e.g. to allow for phase retrieval).
        """
        # Set the mask first, so that an invalid mask fails before the (costly) dezinger step
        if mask is not None:
            self.update_mask(mask=mask)
        else:
            self._setmask(prop=prop)

        self.projections = self.ccdfilter.dezinger_correction(projections, self.dark)

        with ThreadPool(self.n_threads) as tp:
            for i, corrected, _solution in tp.map(self._readcorrect1, range(len(projections))):
                projections[i] = np.exp(-corrected)
        return projections

    def _readcorrect1(self, ii):
        """Method to allow parallelization of the normalization."""
        corr, s = self.correctproj(self.projections[ii])
        return ii, corr, s

    def correctproj(self, projection):
        """Performs the correction on one projection of the stack.

        Parameters
        ----------
        projection: np.ndarray, float
            Radiograph from the acquisition stack.

        Returns
        -------
        corrected: np.ndarray
            ``mean - log(projection - dark)`` minus the fitted model (log space).
        solution: np.ndarray
            The fitted coefficients of the components.
        """
        logp = np.log(projection.astype(np.float32) - self.dark)
        corr = self.mean - logp
        return self.fit(corr)

    def fit(self, corr):
        """Fit the (masked) projection to the (masked) components of the PCA decomposition.

        This is done for each projection, so worth optimising.

        Returns
        -------
        residual: np.ndarray
            ``corr`` minus the fitted combination of components (on the whole image).
        solution: np.ndarray
            The least-squares coefficients, one per component.
        """
        y = corr[self.mask]
        solution = np.linalg.lstsq(self.Amat, y, rcond=None)[0]
        correction = np.einsum("ijk,i->jk", self.components, solution)
        return corr - correction, solution
|
616
|
+
|
617
|
+
|
618
|
+
class PCAFlatsDecomposer:
    """This class implements a PCA decomposition of a series of acquired flatfields.
    The PCA decomposition is used to normalize the projections through a PCAFlatsNormalizer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.
    """

    def __init__(self, flats: np.ndarray, darks: np.ndarray, nsigma=3):
        """
        Parameters
        ----------
        flats: np.ndarray
            A stack of flat field images. The dark is subtracted here, before taking the log.
            The input array is not modified.
        darks: np.ndarray
            An image or stack of images of the dark current images of the camera.
            A stack is reduced with a median.
        nsigma: float (default: 3)
            Used both for the zinger detection on the flats, and for the selection of the PCA components.

        Does the dark subtraction, dezinger and log scaling.
        Subtracts mean and does eigenvector decomposition.
        """
        self.n_threads = max(1, get_num_threads() // 2)
        self.flats = np.empty(flats.shape, dtype=np.float32)

        darks = darks.astype(np.float32)  # astype() returns a copy
        self.dark = np.median(darks, axis=0) if darks.ndim == 3 else darks
        del darks

        self.nsigma = nsigma
        self.ccdfilter = CCDFilter(self.dark.shape)
        self._ccdfilter_and_log(flats)  # Log is taken here (after dezinger)

        self.mean = np.mean(self.flats, axis=0)
        self.flats -= self.mean  # in-place: no additional stack-sized copy
        self.cov = self.compute_correlation_matrix()
        self.compute_pca()
        self.generate_pca_flats(nsigma=nsigma)

    def __str__(self):
        return f"PCA decomposition from flat images. \nThere are {len(self.components)} components created at {self.sigma} level."

    def _ccdfilter_and_log(self, flats: np.ndarray):
        """Subtract the dark, dezinger (median filter based) and take the log of the flat stack.

        The result is stored in ``self.flats``.
        """
        # Dark-subtracted flats, computed directly into the float32 buffer
        np.subtract(flats, self.dark, out=self.flats)
        # The dark is already subtracted: pass a zero dark so that the replaced
        # pixels and the untouched ones stay in the same (dark-subtracted) frame.
        self.ccdfilter.dezinger_correction(self.flats, np.zeros_like(self.dark), nsigma=self.nsigma, output=self.flats)

        with ThreadPool(self.n_threads) as tp:
            # tp.map() computes all the frames before returning, so writing back is safe
            for i, frame in enumerate(tp.map(np.log, self.flats)):
                self.flats[i] = frame

    @staticmethod
    def ccij(args):
        """Compute (img[i] * img[j]).sum() / norm.

        It is a wrapper for threading.
        args == i, j, norm, imgs
        """
        i, j, norm, imgs = args
        return i, j, np.einsum("ij,ij", imgs[i], imgs[j]) / norm

    def compute_correlation_matrix(self):
        """Computes an (nflats x nflats) correlation matrix.

        Entries are normalized by the number of flats: this global scale factor
        does not change the eigenvectors.
        """
        n_flats = len(self.flats)
        cc = np.zeros((n_flats, n_flats), float)
        # The matrix is symmetric: only compute the lower triangle
        args = [(i, j, n_flats, self.flats) for i in range(n_flats) for j in range(i + 1)]
        with ThreadPool(self.n_threads) as tp:
            for i, j, result in tp.map(self.ccij, args):
                cc[i, j] = cc[j, i] = result
        return cc

    def compute_pca(self):
        """Gets eigenvectors and eigenvalues and sorts them by decreasing absolute eigenvalue."""
        # eigh is the solver dedicated to symmetric matrices (real eigenvalues, orthonormal eigenvectors)
        self.eigenvalues, self.eigenvectors = np.linalg.eigh(self.cov)
        order = np.argsort(abs(self.eigenvalues))[::-1]  # high to low
        self.eigenvalues = self.eigenvalues[order]
        self.eigenvectors = self.eigenvectors[:, order]

    def generate_pca_flats(self, nsigma=3):
        """Projects the eigenvectors back into image space.

        Components whose absolute eigenvalue exceeds ``nsigma`` times the second
        smallest one are kept. Two "simple gradient" components (horizontal and
        vertical) are appended.

        Parameters
        ----------
        nsigma: int (default: 3)

        Raises
        ------
        ValueError
            If less than two eigenvalues are available (less than two flats).
        """
        self.sigma = nsigma
        av = abs(self.eigenvalues)
        if av.size < 2:
            raise ValueError("At least two flats are needed to compute the PCA decomposition")
        n_components = int((av > (av[-2] * nsigma)).sum())

        def calculate(ii):
            calc = np.einsum("i,ijk->jk", self.eigenvectors[:, ii], self.flats)
            norm = (calc**2).sum()
            return ii, calc / np.sqrt(norm)

        with ThreadPool(self.n_threads) as tp:
            # tp.map() preserves the order of the inputs
            self.components = [result for _, result in tp.map(calculate, range(n_components))]

        # simple gradients. Frame shape is taken from the flats, so that this also works with no PCA component
        r, c = self.flats.shape[1:]
        self.components.append(np.outer(np.ones(r), np.linspace(-1 / c, 1 / c, c)))
        self.components.append(np.outer(np.linspace(-1 / r, 1 / r, r), np.ones(c)))

    def save_decomposition(self, path="PCA_flats.h5", overwrite=True, entry="entry0000"):
        """Saves the basic information of a PCA decomposition in view of the normalization of projections.

        Parameters
        ----------
        path: str (default: "PCA_flats.h5")
            Full path to the h5 file you want to save your results.
            The file is opened in "w" mode: when overwriting, the whole file is replaced, not only ``entry``.
        overwrite: bool (default: True)
            Whether an existing file can be overwritten.
        entry: str (default: "entry0000")
            Name of the HDF5 group where the datasets are written.

        Raises
        ------
        OSError
            If the file already exists and ``overwrite`` is False.
        """
        if os.path.exists(path) and not overwrite:
            raise OSError(f"The file {path} already exists and you chose to NOT overwrite.")
        with HDF5File(path, "w") as hout:
            group = hout.create_group(entry)
            group["eigenvalues"] = self.eigenvalues
            group["dark"] = self.dark.astype(np.float32)
            group["p_mean"] = self.mean.astype(np.float32)
            group["p_components"] = np.array(self.components)
            hout.flush()
|