nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/cast_volume.py +9 -1
  4. nabu/app/cli_configs.py +80 -3
  5. nabu/app/estimate_motion.py +54 -0
  6. nabu/app/multicor.py +2 -4
  7. nabu/app/pcaflats.py +116 -0
  8. nabu/app/reconstruct.py +1 -7
  9. nabu/app/reduce_dark_flat.py +5 -2
  10. nabu/estimation/cor.py +1 -1
  11. nabu/estimation/motion.py +557 -0
  12. nabu/estimation/tests/test_motion_estimation.py +471 -0
  13. nabu/estimation/tilt.py +1 -1
  14. nabu/estimation/translation.py +47 -1
  15. nabu/io/cast_volume.py +100 -13
  16. nabu/io/reader.py +32 -1
  17. nabu/io/tests/test_remove_volume.py +152 -0
  18. nabu/pipeline/config_validators.py +42 -43
  19. nabu/pipeline/estimators.py +255 -0
  20. nabu/pipeline/fullfield/chunked.py +67 -43
  21. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  22. nabu/pipeline/fullfield/nabu_config.py +20 -14
  23. nabu/pipeline/fullfield/processconfig.py +17 -3
  24. nabu/pipeline/fullfield/reconstruction.py +4 -1
  25. nabu/pipeline/params.py +12 -0
  26. nabu/pipeline/tests/test_estimators.py +240 -3
  27. nabu/preproc/ccd.py +53 -3
  28. nabu/preproc/flatfield.py +306 -1
  29. nabu/preproc/shift.py +3 -1
  30. nabu/preproc/tests/test_pcaflats.py +154 -0
  31. nabu/processing/rotation_cuda.py +3 -1
  32. nabu/processing/tests/test_rotation.py +4 -2
  33. nabu/reconstruction/astra.py +245 -0
  34. nabu/reconstruction/fbp.py +7 -0
  35. nabu/reconstruction/fbp_base.py +31 -7
  36. nabu/reconstruction/fbp_opencl.py +8 -0
  37. nabu/reconstruction/filtering_opencl.py +2 -0
  38. nabu/reconstruction/mlem.py +47 -13
  39. nabu/reconstruction/tests/test_filtering.py +13 -2
  40. nabu/reconstruction/tests/test_mlem.py +91 -62
  41. nabu/resources/dataset_analyzer.py +144 -20
  42. nabu/resources/nxflatfield.py +101 -35
  43. nabu/resources/tests/test_nxflatfield.py +1 -1
  44. nabu/resources/utils.py +16 -10
  45. nabu/stitching/alignment.py +7 -7
  46. nabu/stitching/config.py +22 -20
  47. nabu/stitching/definitions.py +2 -2
  48. nabu/stitching/overlap.py +4 -4
  49. nabu/stitching/sample_normalization.py +5 -5
  50. nabu/stitching/stitcher/post_processing.py +5 -3
  51. nabu/stitching/stitcher/pre_processing.py +24 -20
  52. nabu/stitching/tests/test_config.py +3 -3
  53. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  54. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  55. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  56. nabu/stitching/utils/utils.py +7 -7
  57. nabu/testutils.py +1 -4
  58. nabu/utils.py +13 -0
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/METADATA +3 -4
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/RECORD +64 -57
  61. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/WHEEL +1 -1
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/entry_points.txt +2 -1
  63. nabu/app/correct_rot.py +0 -62
  64. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/licenses/LICENSE +0 -0
  65. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,23 @@
1
1
  import os
2
+ from tempfile import TemporaryDirectory
2
3
  import pytest
3
4
  import numpy as np
4
- from nabu.testutils import utilstest, __do_long_tests__
5
- from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset
5
+ from pint import get_application_registry
6
+ from nxtomo import NXtomo
7
+ from nabu.testutils import utilstest, __do_long_tests__, get_data
8
+ from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset, ImageKey
6
9
  from nabu.resources.nxflatfield import update_dataset_info_flats_darks
7
10
  from nabu.resources.utils import extract_parameters
8
- from nabu.pipeline.estimators import CompositeCOREstimator
11
+ from nabu.pipeline.estimators import CompositeCOREstimator, TranslationsEstimator
9
12
  from nabu.pipeline.config import parse_nabu_config_file
10
13
  from nabu.pipeline.estimators import SinoCORFinder, CORFinder
11
14
 
15
+ from nabu.estimation.tests.test_motion_estimation import (
16
+ check_motion_estimation,
17
+ project_volume,
18
+ _create_translations_vector,
19
+ )
20
+
12
21
 
13
22
  #
14
23
  # Test CoR estimation with "composite-coarse-to-fine" (aka "near" in the legacy system vocable)
@@ -119,3 +128,231 @@ class TestCorNearPos:
119
128
  cor = finder.find_cor()
120
129
  message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}."
121
130
  assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message
131
+
132
+
133
+ def _add_fake_flats_and_dark_to_data(data, n_darks=10, n_flats=21, dark_val=1, flat_val=3):
134
+ img_shape = data.shape[1:]
135
+ # Use constant darks/flats, to avoid "reduction" (mean/median) issues
136
+ fake_darks = np.ones((n_darks,) + img_shape, dtype=np.uint16) * dark_val
137
+ fake_flats = np.ones((n_flats,) + img_shape, dtype=np.uint16) * flat_val
138
+ return data * (fake_flats[0, 0, 0] - fake_darks[0, 0, 0]) + fake_darks[0, 0, 0], fake_darks, fake_flats
139
+
140
+
141
def _generate_nx_for_180_dataset(volume, output_file_path, n_darks=10, n_flats=21):
    """
    Simulate a 180 degrees scan of `volume` with a known sample motion, and save it as NXtomo.

    250 projections are taken over [0, 180[ degrees, followed by 5 "return" projections
    (180, 135, 90, 45, 0 degrees) flagged as ALIGNMENT frames. Fake constant darks and
    flats are prepended to the data.

    Returns
    -------
    sample_motion_xy, sample_motion_z, cor
        Ground-truth sample motion for all projections (return projections included),
        and the center of rotation used for the simulation.
    """
    n_angles = 250
    cor = -10

    # Coefficients of the synthetic translations, see _create_translations_vector()
    alpha_x, beta_x = 4, 3
    alpha_y, beta_y = -5, 10
    beta_z = 0
    orig_det_dist = 0

    scan_angles = np.linspace(0, np.pi, n_angles, False)
    return_angles = np.deg2rad([180.0, 135.0, 90.0, 45.0, 0.0])
    all_angles = np.hstack([scan_angles, return_angles]).ravel()
    # Normalized "time" of each projection (return projections come after the scan)
    a = np.arange(scan_angles.size + return_angles.size) / scan_angles.size

    tx = _create_translations_vector(a, alpha_x, beta_x)
    ty = _create_translations_vector(a, alpha_y, beta_y)
    tz = _create_translations_vector(a, 0, beta_z)

    sinos = project_volume(volume, all_angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
    data = np.moveaxis(sinos, 1, 0)

    sample_motion_xy = np.stack([-tx, ty], axis=1)
    sample_motion_z = -tz
    scan_angles_deg = np.degrees(scan_angles)
    return_angles_deg = np.degrees(return_angles)
    n_return_radios = return_angles.size
    n_radios = data.shape[0] - n_return_radios

    ureg = get_application_registry()
    fake_raw_data, darks, flats = _add_fake_flats_and_dark_to_data(data, n_darks=n_darks, n_flats=n_flats)

    nxtomo = NXtomo()
    # radios + return radios (in float32 !)
    nxtomo.instrument.detector.data = np.concatenate([darks, flats, fake_raw_data])
    nxtomo.instrument.detector.image_key_control = np.concatenate(
        [
            [ImageKey.DARK_FIELD.value] * n_darks,
            [ImageKey.FLAT_FIELD.value] * n_flats,
            [ImageKey.PROJECTION.value] * n_radios,
            [ImageKey.ALIGNMENT.value] * n_return_radios,
        ]
    )

    # Darks and flats are given a zero rotation angle
    rotation_angle = np.concatenate(
        [
            np.zeros(n_darks, dtype="f"),
            np.zeros(n_flats, dtype="f"),
            scan_angles_deg,
            return_angles_deg,
        ]
    )
    nxtomo.sample.rotation_angle = rotation_angle * ureg.degree
    nxtomo.instrument.detector.field_of_view = "Full"
    nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
    nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

    return sample_motion_xy, sample_motion_z, cor
202
+
203
+
204
def _generate_nx_for_360_dataset(volume, output_file_path, n_darks=10, n_flats=21):
    """
    Simulate a 360 degrees scan of `volume` with a known sample motion, and save it as NXtomo.

    250 projections are taken over [0, 360[ degrees. Fake constant darks and flats are
    prepended to the data.

    Returns
    -------
    sample_motion_xy, sample_motion_z, cor
        Ground-truth sample motion for each projection, and the center of rotation
        used for the simulation.
    """
    n_angles = 250
    cor = -5.5

    # Coefficients of the synthetic translations, see _create_translations_vector()
    alpha_x, beta_x = -2, 7.0
    alpha_y, beta_y = -2, 3
    beta_z = 100
    orig_det_dist = 0

    angles = np.linspace(0, 2 * np.pi, n_angles, False)
    a = np.linspace(0, 1, angles.size, endpoint=False)  # theta/theta_max

    tx = _create_translations_vector(a, alpha_x, beta_x)
    ty = _create_translations_vector(a, alpha_y, beta_y)
    tz = _create_translations_vector(a, 0, beta_z)

    sinos = project_volume(volume, angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
    data = np.moveaxis(sinos, 1, 0)
    n_radios = data.shape[0]

    sample_motion_xy = np.stack([-tx, ty], axis=1)
    sample_motion_z = -tz
    angles_deg = np.degrees(angles)

    ureg = get_application_registry()

    fake_raw_data, darks, flats = _add_fake_flats_and_dark_to_data(data, n_darks=n_darks, n_flats=n_flats)

    nxtomo = NXtomo()
    nxtomo.instrument.detector.data = np.concatenate([darks, flats, fake_raw_data])  # in float32 !
    nxtomo.instrument.detector.image_key_control = np.concatenate(
        [
            [ImageKey.DARK_FIELD.value] * n_darks,
            [ImageKey.FLAT_FIELD.value] * n_flats,
            [ImageKey.PROJECTION.value] * n_radios,
        ]
    )

    # Darks and flats are given a zero rotation angle
    dark_flat_angles = [np.zeros(n_darks, dtype="f"), np.zeros(n_flats, dtype="f")]
    rotation_angle = np.concatenate(dark_flat_angles + [angles_deg])
    nxtomo.sample.rotation_angle = rotation_angle * ureg.degree
    nxtomo.instrument.detector.field_of_view = "Full"
    nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
    nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

    return sample_motion_xy, sample_motion_z, cor
259
+
260
+
261
@pytest.fixture(scope="class")
def setup_test_motion_estimator(request):
    """Load the test volume once per test class, and attach it as the class attribute `volume`."""
    test_class = request.cls
    test_class.volume = get_data("motion/mri_volume_subsampled.npy")
265
+
266
+
267
@pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1")
@pytest.mark.usefixtures("setup_test_motion_estimator")
class TestMotionEstimator:
    """
    End-to-end tests of TranslationsEstimator on synthetic datasets (360 and 180 degrees)
    generated with a known sample motion.
    """

    def _setup(self, tmpdir):
        """
        Prepare a test: make sure the test volume is loaded, and return `tmpdir` as a str.

        pytest's "tmpdir" fixture is not a str (py.path object), whereas the __main__
        entry point passes a plain str. The converted path is returned so that callers
        actually use it.
        """
        tmpdir = str(tmpdir)
        # The class fixture is not used when running this file as a script
        if getattr(self, "volume", None) is None:
            self.volume = get_data("motion/mri_volume_subsampled.npy")
        return tmpdir

    def test_estimate_motion_360_dataset(self, tmpdir, verbose=False):
        tmpdir = self._setup(tmpdir)
        nx_file_path = os.path.join(tmpdir, "mri_projected_360_motion.nx")
        sample_motion_xy, sample_motion_z, cor = _generate_nx_for_360_dataset(self.volume, nx_file_path)

        dataset_info = analyze_dataset(nx_file_path)

        translations_estimator = TranslationsEstimator(
            dataset_info, do_flatfield=True, rot_center=cor, angular_subsampling=5, deg_xy=2, deg_z=2
        )
        estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()

        # Ground truth has to be subsampled the same way as the projections used by the estimator
        s = translations_estimator.angular_subsampling
        if verbose:
            translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
            translations_estimator.motion_estimator.plot_movements(
                cor=cor,
                angles_rad=dataset_info.rotation_angles[::s],
                gt_xy=sample_motion_xy[::s, :],
                gt_z=sample_motion_z[::s],
            )
        check_motion_estimation(
            translations_estimator.motion_estimator,
            dataset_info.rotation_angles[::s],
            cor,
            sample_motion_xy[::s, :],
            sample_motion_z[::s],
            fit_error_shifts_tol_vu=(0.2, 0.2),
            fit_error_det_tol_vu=(1e-5, 5e-2),
            fit_error_tol_xyz=(0.05, 0.05, 0.05),
            fit_error_det_all_angles_tol_vu=(1e-5, 0.05),
        )

    def test_estimate_motion_180_dataset(self, tmpdir, verbose=False):
        tmpdir = self._setup(tmpdir)
        nx_file_path = os.path.join(tmpdir, "mri_projected_180_motion.nx")

        sample_motion_xy, sample_motion_z, cor = _generate_nx_for_180_dataset(self.volume, nx_file_path)

        dataset_info = analyze_dataset(nx_file_path)

        translations_estimator = TranslationsEstimator(
            dataset_info,
            do_flatfield=True,
            rot_center=cor,
            angular_subsampling=2,
            deg_xy=2,
            deg_z=2,
            shifts_estimator="DetectorTranslationAlongBeam",
        )
        estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()

        if verbose:
            translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
            translations_estimator.motion_estimator.plot_movements(
                cor=cor,
                angles_rad=dataset_info.rotation_angles,
                gt_xy=sample_motion_xy[: dataset_info.n_angles],
                gt_z=sample_motion_z[: dataset_info.n_angles],
            )

        check_motion_estimation(
            translations_estimator.motion_estimator,
            dataset_info.rotation_angles,
            cor,
            sample_motion_xy,
            sample_motion_z,
            fit_error_shifts_tol_vu=(0.02, 0.1),
            fit_error_det_tol_vu=(1e-2, 0.5),
            fit_error_tol_xyz=(0.5, 2, 1e-2),
            fit_error_det_all_angles_tol_vu=(1e-2, 2),
        )
351
+
352
+
353
if __name__ == "__main__":
    # Run the (long) motion estimation tests outside of pytest, with plots enabled
    tester = TestMotionEstimator()
    with TemporaryDirectory(suffix="_motion", prefix="nabu_testdata") as tmp_dir:
        for run_test in (tester.test_estimate_motion_360_dataset, tester.test_estimate_motion_180_dataset):
            run_test(tmp_dir, verbose=True)
nabu/preproc/ccd.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import numpy as np
2
2
  from ..utils import check_supported
3
+ from scipy.ndimage import binary_dilation
3
4
  from silx.math.medianfilter import medfilt2d
4
5
 
5
6
 
@@ -13,6 +14,7 @@ class CCDFilter:
13
14
  def __init__(
14
15
  self,
15
16
  radios_shape: tuple,
17
+ kernel_size: int = 3,
16
18
  correction_type: str = "median_clip",
17
19
  median_clip_thresh: float = 0.1,
18
20
  abs_diff=False,
@@ -26,6 +28,9 @@ class CCDFilter:
26
28
  radios_shape: tuple
27
29
  A tuple describing the shape of the radios stack, in the form
28
30
  `(n_radios, n_z, n_x)`.
31
+ kernel_size: int
32
+ Size of the kernel for the median filter.
33
+ Default is 3.
29
34
  correction_type: str
30
35
  Correction type for radios ("median_clip", "sigma_clip", ...)
31
36
  median_clip_thresh: float, optional
@@ -48,6 +53,7 @@ class CCDFilter:
48
53
  then this pixel value is set to the median value.
49
54
  """
50
55
  self._set_radios_shape(radios_shape)
56
+ self.kernel_size = kernel_size
51
57
  check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode")
52
58
  self.correction_type = correction_type
53
59
  self.median_clip_thresh = median_clip_thresh
@@ -67,11 +73,11 @@ class CCDFilter:
67
73
  self.shape = (n_z, n_x)
68
74
 
69
75
  @staticmethod
70
- def median_filter(img):
76
+ def median_filter(img, kernel_size=3):
71
77
  """
72
78
  Perform a median filtering on an image.
73
79
  """
74
- return medfilt2d(img, (3, 3), mode="reflect")
80
+ return medfilt2d(img, (kernel_size, kernel_size), mode="reflect")
75
81
 
76
82
  def median_clip_mask(self, img, return_medians=False):
77
83
  """
@@ -85,7 +91,7 @@ class CCDFilter:
85
91
  return_medians: bool, optional
86
92
  Whether to return the median values additionally to the mask.
87
93
  """
88
- median_values = self.median_filter(img)
94
+ median_values = self.median_filter(img, kernel_size=self.kernel_size)
89
95
  if not self.abs_diff:
90
96
  invalid_mask = img >= median_values + self.median_clip_thresh
91
97
  else:
@@ -124,6 +130,50 @@ class CCDFilter:
124
130
 
125
131
  return output
126
132
 
133
def dezinger_correction(self, radios, dark=None, nsigma=5, output=None):
    """
    Remove zingers from a radios stack with a median clip, and propagate the
    invalid pixels into the vertical and horizontal directions.

    For each radio, the (optionally dark-subtracted) image is compared to its median
    filtered version. Pixels where the difference exceeds `nsigma` times the standard
    deviation of the difference image are flagged. The mask is grown with a binary
    dilation, and flagged pixels are replaced by the median value.

    Parameters
    ----------
    radios: numpy.ndarray
        A radios stack.
    dark: numpy.ndarray, optional
        A dark image. Default is None. If not None, it is subtracted from the radios for
        the zinger detection. The output stays in the same domain as `radios`: the dark is
        added back to the replacement values.
    nsigma: float, optional
        Number of standard deviations to use for the zinger detection.
        Default is 5.
    output: numpy.ndarray, optional
        Output array. If None, a copy of `radios` is created.

    Returns
    -------
    output: numpy.ndarray
        The corrected radios stack.

    Raises
    ------
    ValueError
        If the radios shape does not match the one this filter was configured for, or if
        the dark shape does not match the shape of one radio.
    """
    if radios.shape[1:] != self.radios_shape[1:]:
        raise ValueError(f"Expected radios shape {self.radios_shape}, got {radios.shape}")
    if dark is not None and dark.shape != radios.shape[1:]:
        raise ValueError("Dark image shape does not match radios shape.")

    if output is None:
        output = np.copy(radios)
    else:
        output[:] = radios[:]

    for i in range(radios.shape[0]):
        img = radios[i]
        dimg = img if dark is None else img - dark
        med = self.median_filter(dimg, self.kernel_size)
        err = dimg - med
        msk = err > err.std() * nsigma
        # Default structuring element: grows the mask by one pixel vertically and horizontally
        grown_mask = binary_dilation(msk)
        # "med" is computed on the dark-subtracted image: bring it back to the radios' domain,
        # otherwise replaced pixels would be offset by the dark value w.r.t. their neighbors.
        replacement = med if dark is None else med + dark
        output[i] = np.where(grown_mask, replacement, img)

    return output
176
+
127
177
 
128
178
  class Log:
129
179
  """
nabu/preproc/flatfield.py CHANGED
@@ -1,8 +1,11 @@
1
+ import os
1
2
  from multiprocessing.pool import ThreadPool
2
3
  from bisect import bisect_left
3
4
  import numpy as np
5
+ from tomoscan.io import HDF5File
4
6
  from ..io.reader import load_images_from_dataurl_dict
5
7
  from ..utils import check_supported, deprecated_class, get_num_threads
8
+ from .ccd import CCDFilter
6
9
 
7
10
 
8
11
  class FlatFieldArrays:
@@ -102,7 +105,7 @@ class FlatFieldArrays:
102
105
  self._precompute_flats_indices_weights()
103
106
  self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
104
107
  self.distortion_correction = distortion_correction
105
- self.n_threads = min(1, get_num_threads(n_threads) // 2)
108
+ self.n_threads = max(1, get_num_threads(n_threads) // 2)
106
109
 
107
110
  def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
108
111
  self._set_radios_shape(radios_shape)
@@ -450,3 +453,305 @@ class FlatFieldDataUrls(FlatField):
450
453
  radios_srcurrent=radios_srcurrent,
451
454
  flats_srcurrent=flats_srcurrent,
452
455
  )
456
+
457
+
458
class PCAFlatsNormalizer:
    """This class implements a flatfield normalization based on a PCA of a series of acquired flatfields.
    The PCA decomposition is handled by a PCAFlatsDecomposer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.

    """

    def __init__(self, components, dark, mean):
        """Initializes all variables needed to perform the flatfield normalization.

        Parameters
        -----------
        components: ndarray
            The components of the PCA decomposition.
        dark: ndarray
            The dark image. Should be one single 2D image.
        mean: ndarray
            The mean image of the series of flats.
        """
        # A constant component is prepended: it accounts for I0 (incoming beam intensity)
        ones = np.ones_like(components[0], dtype=np.float32)[np.newaxis]
        self.components = np.concatenate([ones, components], axis=0)
        self.dark = dark
        self.mean = mean
        self.n_threads = max(1, get_num_threads() // 2)
        self._setmask()
        self.ccdfilter = CCDFilter(mean.shape)

    def _form_lsq_matrix(self):
        """Form the least-squares matrix: one column per (masked) component."""
        self.Amat = np.stack(self.g, axis=1)

    def _apply_mask(self, mask):
        """Store `mask`, extract the masked pixels of each component, and rebuild the least-squares matrix."""
        self.mask = mask
        self.g = [component[mask] for component in self.components]
        self._form_lsq_matrix()

    def _setmask(self, prop=0.125):
        """Sets the mask to select where the model is going to be fitted.

        Parameters
        ----------
        prop: float, default: 0.125
            The proportion of the image width to take on each side of the image as a mask.

        By default it sets the strips on each side of the frame in the form:
            mask[:, :lim] = True
            mask[:, -lim:] = True
        Where lim = int(prop * flat.shape[1])

        If you need a custom mask, see update_mask() method.

        Raises
        ------
        ValueError
            If `prop` is too small to select at least one column on each side.
            (With lim == 0, the slice [:, -0:] would select the whole image.)
        """
        width = self.mean.shape[1]
        lim = int(prop * width)
        if lim < 1:
            raise ValueError(f"prop={prop} selects no column on the sides of an image of width {width}")
        mask = np.zeros(self.mean.shape, dtype=bool)
        mask[:, :lim] = True
        mask[:, width - lim :] = True
        self._apply_mask(mask)

    def update_mask(self, mask: np.ndarray):
        """Method to update the mask with a custom mask in the form of a boolean 2+D array.

        Parameters
        ----------
        mask: np.ndarray of bool
            The array of boolean allows the selection of the region of the image that will be used to fit against the components.

        It will set the mask, replacing the standard mask created with _setmask().

        Raises
        ------
        TypeError
            If `mask` is not a boolean array. In this case the current mask is kept.
        """
        if mask.dtype != bool:
            raise TypeError("Not a boolean array. Will keep the default mask")
        self._apply_mask(mask)

    def normalize_radios(self, projections, mask=None, prop=0.125):
        """
        This is to keep the flatfield API in the pipeline.

        Normalizes `projections` in-place (see correct_stack()) and returns them.
        """
        return self.correct_stack(projections, mask=mask, prop=prop)

    def correct_stack(
        self,
        projections: np.ndarray,
        mask: np.ndarray = None,
        prop: float = 0.125,
    ):
        """This functions normalizes the stack of projections.

        Performs correction on a stack of projections based on the calculated decomposition.
        The normalization is done in-place. The previous projections before normalization are lost.

        Parameters
        ----------
        projections: ndarray
            Stack of projections to normalize.

        prop: float (default: {0.125})
            Fraction to mask on the horizontal field of view, assuming vertical rotation axis

        mask: np.ndarray (default: None)
            Custom mask if your data requires it.

        Returns
        -------
        corrected projections: np.ndarray
            The `projections` array itself, flat field corrected.
            Note that the returned projections are exp-transformed to fit the pipeline (e.g. to allow for phase retrieval).
        """
        # Dezingered copy of the raw projections, read by the worker threads
        self.projections = self.ccdfilter.dezinger_correction(projections, self.dark)

        if mask is not None:
            self.update_mask(mask=mask)
        else:
            self._setmask(prop=prop)

        with ThreadPool(self.n_threads) as tp:
            for i, residual, _solution in tp.map(self._readcorrect1, range(len(projections))):
                projections[i] = np.exp(-residual)
        return projections

    def _readcorrect1(self, ii):
        """Method to allow parallelization of the normalization: returns (ii, residual, solution) for projection ii."""
        corr, s = self.correctproj(self.projections[ii])
        return ii, corr, s

    def correctproj(self, projection):
        """Performs the correction on one projection of the stack.

        Parameters
        ----------
        projection: np.ndarray, float
            Radiograph from the acquisition stack.

        Returns
        -------
        (residual, solution): tuple
            As returned by fit(), applied to mean - log(projection - dark).
        """
        logp = np.log(projection.astype(np.float32) - self.dark)
        corr = self.mean - logp
        # model to be fitted !!
        return self.fit(corr)

    def fit(self, corr):
        """Fit the (masked) projection to the (masked) components of the PCA decomposition.

        This is for each projection, so worth optimising ...

        Returns
        -------
        (residual, solution): tuple
            residual = corr - sum_i solution[i] * components[i], and the least-squares
            coefficients `solution` (one per component).
        """
        y = corr[self.mask]
        solution = np.linalg.lstsq(self.Amat, y, rcond=None)[0]
        correction = np.einsum("ijk,i->jk", self.components, solution)
        return corr - correction, solution
616
+
617
+
618
class PCAFlatsDecomposer:
    """This class implements a PCA decomposition of a series of acquired flatfields.
    The PCA decomposition is used to normalize the projections through a PCAFlatsNormalizer object.

    This implementation was proposed by Jailin C. et al in https://doi.org/10.1107/S1600577516015812.

    Code initially written by ID11 @ ESRF staff.
    Jonathan Wright - Implementation based on research paper
    Pedro D. Resende - Added saving and loading from file capabilities

    Jerome Lesaint - Integrated the solution in Nabu.
    """

    def __init__(self, flats: np.ndarray, darks: np.ndarray, nsigma=3):
        """

        Parameters
        -----------
        flats: np.ndarray
            A stack of darks corrected flat field images
        darks: np.ndarray
            An image or stack of images of the dark current images of the camera.
            A stack is reduced to one image with a median.
        nsigma: float, optional
            Threshold used both for the zinger detection and for the selection of the
            PCA components (see generate_pca_flats()). Default is 3.

        Dezingers the flats, does the log scaling,
        subtracts the mean and does eigenvector decomposition.
        """
        self.n_threads = max(1, get_num_threads() // 2)
        self.flats = np.empty(flats.shape, dtype=np.float32)

        darks = darks.astype(np.float32)
        if darks.ndim == 3:
            self.dark = np.median(darks, axis=0)
        else:
            self.dark = darks.copy()
        del darks

        self.nsigma = nsigma
        self.ccdfilter = CCDFilter(self.dark.shape)
        self._ccdfilter_and_log(flats)  # Log is taken here (after dezinger)

        self.mean = np.mean(self.flats, axis=0)  # average

        self.flats = self.flats - self.mean  # makes a copy
        self.cov = self.compute_correlation_matrix()
        self.compute_pca()
        self.generate_pca_flats(nsigma=nsigma)

    def __str__(self):
        return f"PCA decomposition from flat images. \nThere are {len(self.components)} components created at {self.sigma} level."

    def _ccdfilter_and_log(self, flats: np.ndarray):
        """Dezinger (using the dark, median filter) and take the log of the flat stack, into self.flats"""

        # dezinger_correction returns a new array (the input is not modified): use its result
        dezingered = self.ccdfilter.dezinger_correction(flats, self.dark, nsigma=self.nsigma)

        with ThreadPool(self.n_threads) as tp:
            for i, frame in enumerate(tp.map(np.log, dezingered)):
                self.flats[i] = frame

    @staticmethod
    def ccij(args):
        """Compute the covariance (img[i]*img[j]).sum() / norm
        It is a wrapper for threading.
        args == i, j, norm, imgs
        (compute_correlation_matrix() passes the number of flats as `norm`)
        """

        i, j, NY, imgs = args
        return i, j, np.einsum("ij,ij", imgs[i], imgs[j]) / NY

    def compute_correlation_matrix(self):
        """Computes an (nflats x nflats) correlation matrix, normalized by the number of flats"""

        N = len(self.flats)
        CC = np.zeros((N, N), float)
        # The matrix is symmetric: only compute the lower triangle
        args = [(i, j, N, self.flats) for i in range(N) for j in range(i + 1)]
        with ThreadPool(self.n_threads) as tp:
            for i, j, result in tp.map(self.ccij, args):
                CC[i, j] = CC[j, i] = result
        return CC

    def compute_pca(self):
        """Gets eigenvectors and eigenvalues and sorts them into order"""

        # eigh is the appropriate solver: the correlation matrix is symmetric,
        # so its eigenvalues are real and its eigenvectors orthonormal.
        self.eigenvalues, self.eigenvectors = np.linalg.eigh(self.cov)
        order = np.argsort(abs(self.eigenvalues))[::-1]  # high to low
        self.eigenvalues = self.eigenvalues[order]
        self.eigenvectors = self.eigenvectors[:, order]

    def generate_pca_flats(self, nsigma=3):
        """Projects the eigenvectors back into image space.

        Parameters
        ----------
        nsigma: int (default: 3)
            A component is kept if its eigenvalue is larger than nsigma times the
            second smallest eigenvalue.

        Two "gradient" components (horizontal and vertical ramps) are always appended.

        Raises
        ------
        ValueError
            If fewer than 2 flats were provided.
        """

        self.sigma = nsigma
        av = abs(self.eigenvalues)
        if av.size < 2:
            raise ValueError("At least 2 flats are needed to compute a PCA decomposition of flats")
        # The mean-subtracted flats sum to zero, so the smallest eigenvalue is (close to) zero:
        # the second smallest one is used as the reference level.
        n_comp = int((av > (av[-2] * nsigma)).sum())
        self.components = [None] * n_comp

        def calculate(ii):
            calc = np.einsum("i,ijk->jk", self.eigenvectors[:, ii], self.flats)
            norm = (calc**2).sum()
            return ii, calc / np.sqrt(norm)

        with ThreadPool(self.n_threads) as tp:
            for ii, result in tp.map(calculate, range(n_comp)):
                self.components[ii] = result

        # simple gradients. Image shape is taken from the flats, as no PCA component may have been kept.
        r, c = self.flats.shape[1:]
        self.components.append(np.outer(np.ones(r), np.linspace(-1 / c, 1 / c, c)))
        self.components.append(np.outer(np.linspace(-1 / r, 1 / r, r), np.ones(c)))

    def save_decomposition(self, path="PCA_flats.h5", overwrite=True, entry="entry0000"):
        """Saves the basic information of a PCA decomposition in view of the normalization of projections.

        Parameters
        ----------
        path: str (default: "PCA_flats.h5")
            Full path to the h5 file you want to save your results. The file is opened in "w" mode,
            so its whole previous content is lost. Be careful.
        overwrite: bool (default: True)
            Whether an existing file can be overwritten.
        entry: str (default: "entry0000")
            Name of the HDF5 group where eigenvalues, dark, p_mean and p_components are written.

        Raises
        ------
        OSError
            If the file already exists and `overwrite` is False.
        """

        if os.path.exists(path) and not overwrite:
            raise OSError(f"The file {path} already exists and you chose to NOT overwrite.")
        with HDF5File(path, "w") as hout:
            group = hout.create_group(entry)
            group["eigenvalues"] = self.eigenvalues
            group["dark"] = self.dark.astype(np.float32)
            group["p_mean"] = self.mean.astype(np.float32)
            group["p_components"] = np.array(self.components)
            hout.flush()