nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +14 -0
  136. nabu/reconstruction/fbp_base.py +40 -8
  137. nabu/reconstruction/fbp_opencl.py +8 -0
  138. nabu/reconstruction/filtering.py +59 -25
  139. nabu/reconstruction/filtering_cuda.py +22 -21
  140. nabu/reconstruction/filtering_opencl.py +10 -14
  141. nabu/reconstruction/hbp.py +26 -13
  142. nabu/reconstruction/mlem.py +55 -16
  143. nabu/reconstruction/projection.py +3 -5
  144. nabu/reconstruction/sinogram.py +1 -1
  145. nabu/reconstruction/sinogram_cuda.py +0 -1
  146. nabu/reconstruction/tests/test_cone.py +37 -2
  147. nabu/reconstruction/tests/test_deringer.py +4 -4
  148. nabu/reconstruction/tests/test_fbp.py +36 -15
  149. nabu/reconstruction/tests/test_filtering.py +27 -7
  150. nabu/reconstruction/tests/test_halftomo.py +28 -2
  151. nabu/reconstruction/tests/test_mlem.py +94 -64
  152. nabu/reconstruction/tests/test_projector.py +7 -2
  153. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  154. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  155. nabu/resources/dataset_analyzer.py +210 -24
  156. nabu/resources/gpu.py +4 -4
  157. nabu/resources/logger.py +4 -4
  158. nabu/resources/nxflatfield.py +103 -37
  159. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  160. nabu/resources/tests/test_extract.py +11 -0
  161. nabu/resources/tests/test_nxflatfield.py +5 -5
  162. nabu/resources/utils.py +16 -10
  163. nabu/stitching/alignment.py +8 -11
  164. nabu/stitching/config.py +44 -35
  165. nabu/stitching/definitions.py +2 -2
  166. nabu/stitching/frame_composition.py +8 -10
  167. nabu/stitching/overlap.py +4 -4
  168. nabu/stitching/sample_normalization.py +5 -5
  169. nabu/stitching/slurm_utils.py +2 -2
  170. nabu/stitching/stitcher/base.py +2 -0
  171. nabu/stitching/stitcher/dumper/base.py +0 -1
  172. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  173. nabu/stitching/stitcher/post_processing.py +11 -9
  174. nabu/stitching/stitcher/pre_processing.py +37 -31
  175. nabu/stitching/stitcher/single_axis.py +2 -3
  176. nabu/stitching/stitcher_2D.py +2 -1
  177. nabu/stitching/tests/test_config.py +10 -11
  178. nabu/stitching/tests/test_sample_normalization.py +1 -1
  179. nabu/stitching/tests/test_slurm_utils.py +1 -2
  180. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  181. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  182. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  183. nabu/stitching/utils/tests/__init__.py +0 -0
  184. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  185. nabu/stitching/utils/utils.py +16 -18
  186. nabu/tests.py +0 -3
  187. nabu/testutils.py +62 -9
  188. nabu/utils.py +50 -20
  189. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  190. nabu-2025.1.0.dist-info/RECORD +328 -0
  191. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  192. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  193. nabu/app/correct_rot.py +0 -70
  194. nabu/io/tests/test_detector_distortion.py +0 -178
  195. nabu-2024.2.13.dist-info/RECORD +0 -317
  196. /nabu/{stitching → app}/tests/__init__.py +0 -0
  197. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  198. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/preproc/ctf.py CHANGED
@@ -67,7 +67,7 @@ class GeoPars:
67
67
  else:
68
68
  self.M_vh = np.array([1, 1])
69
69
 
70
- self.logger.debug("Magnification : h ({}) ; v ({}) ".format(self.M_vh[1], self.M_vh[0]))
70
+ self.logger.debug(f"Magnification : h ({self.M_vh[1]}) ; v ({self.M_vh[0]}) ")
71
71
 
72
72
  self.length_scale = length_scale
73
73
  self.wavelength = wavelength
@@ -80,17 +80,17 @@ class GeoPars:
80
80
 
81
81
  which_unit = int(np.sum(np.array([self.pix_size_rec > small for small in [1.0e-6, 1.0e-7]]).astype(np.int32)))
82
82
  self.pixelsize_string = [
83
- "{:.1f} nm".format(self.pix_size_rec * 1e9),
84
- "{:.3f} um".format(self.pix_size_rec * 1e6),
85
- "{:.1f} um".format(self.pix_size_rec * 1e6),
83
+ f"{self.pix_size_rec * 1e9:.1f} nm",
84
+ f"{self.pix_size_rec * 1e6:.3f} um",
85
+ f"{self.pix_size_rec * 1e6:.1f} um",
86
86
  ][which_unit]
87
87
 
88
88
  if self.magnification:
89
89
  self.logger.debug(
90
- "All images are resampled to smallest pixelsize: {}".format(self.pixelsize_string),
90
+ f"All images are resampled to smallest pixelsize: {self.pixelsize_string}",
91
91
  )
92
92
  else:
93
- self.logger.debug("Pixelsize images: {}".format(self.pixelsize_string))
93
+ self.logger.debug(f"Pixelsize images: {self.pixelsize_string}")
94
94
 
95
95
 
96
96
  class CTFPhaseRetrieval:
@@ -157,7 +157,7 @@ class CTFPhaseRetrieval:
157
157
  """
158
158
  self.logger = LoggerOrPrint(logger)
159
159
  if not isinstance(geo_pars, GeoPars):
160
- raise ValueError("Expected GeoPars instance for 'geo_pars' parameter")
160
+ raise TypeError("Expected GeoPars instance for 'geo_pars' parameter")
161
161
  self.geo_pars = geo_pars
162
162
  self._calc_shape(shape, padded_shape, padding_mode)
163
163
  self.delta_beta = delta_beta
@@ -292,7 +292,7 @@ class CTFPhaseRetrieval:
292
292
  self.cut_v = math.sqrt(1.0 / 2 / lambda_dist_vh[0]) / fsample_vh[0]
293
293
  self.cut_v = min(self.cut_v, 0.5)
294
294
 
295
- self.logger.debug("Normalized cut-off = {:5.3f}".format(self.cut_v))
295
+ self.logger.debug(f"Normalized cut-off = {self.cut_v:5.3f}")
296
296
 
297
297
  self.r = fourier_filters.get_lowpass_filter(
298
298
  padded_img_shape,
nabu/preproc/ctf_cuda.py CHANGED
@@ -48,7 +48,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
48
48
  padding_mode: str
49
49
  Padding mode. Default is "reflect".
50
50
 
51
- Other parameters
51
+ Other Parameters
52
52
  -----------------
53
53
  Please refer to CTFPhaseRetrieval documentation.
54
54
  """
@@ -7,7 +7,7 @@ from .ccd_cuda import CudaLog
7
7
 
8
8
  if __has_pycuda__:
9
9
  import pycuda.gpuarray as garray
10
- import pycuda.cumath as cumath
10
+ from pycuda import cumath
11
11
 
12
12
 
13
13
  class CudaDoubleFlatField(DoubleFlatField):
@@ -109,7 +109,7 @@ class CudaDoubleFlatField(DoubleFlatField):
109
109
  Whether to recompute the double flatfield if already computed.
110
110
  """
111
111
  if not (isinstance(radios, garray.GPUArray)): # pylint: disable=E0606
112
- raise ValueError("Expected pycuda.gpuarray.GPUArray for radios")
112
+ raise TypeError("Expected pycuda.gpuarray.GPUArray for radios")
113
113
  if self._computed and not (recompute):
114
114
  return self.doubleflatfield
115
115
  acc = garray.zeros(radios[0].shape, "f")
@@ -1,5 +1,4 @@
1
1
  from .double_flatfield import (
2
- DoubleFlatField,
3
2
  DoubleFlatField,
4
3
  check_shape,
5
4
  get_2D_3D_shape,
nabu/preproc/flatfield.py CHANGED
@@ -1,8 +1,11 @@
1
+ import os
1
2
  from multiprocessing.pool import ThreadPool
2
3
  from bisect import bisect_left
3
4
  import numpy as np
5
+ from tomoscan.io import HDF5File
4
6
  from ..io.reader import load_images_from_dataurl_dict
5
7
  from ..utils import check_supported, deprecated_class, get_num_threads
8
+ from .ccd import CCDFilter
6
9
 
7
10
 
8
11
  class FlatFieldArrays:
@@ -91,7 +94,7 @@ class FlatFieldArrays:
91
94
  flatfield_normalization(X') = (X' - D)/(F' - D) = (X - D) / (F - D) * sF/sX
92
95
  So current normalization boils down to a scalar multiplication after flat-field.
93
96
  """
94
- if self._full_shape:
97
+ if self._full_shape: # noqa: SIM102
95
98
  # this is never going to happen in this base class. But in the derived class for helical
96
99
  # which needs to keep the full shape
97
100
  if radios_indices is not None:
@@ -102,7 +105,7 @@ class FlatFieldArrays:
102
105
  self._precompute_flats_indices_weights()
103
106
  self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
104
107
  self.distortion_correction = distortion_correction
105
- self.n_threads = min(1, get_num_threads(n_threads) // 2)
108
+ self.n_threads = max(1, get_num_threads(n_threads) // 2)
106
109
 
107
110
  def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
108
111
  self._set_radios_shape(radios_shape)
@@ -450,3 +453,305 @@ class FlatFieldDataUrls(FlatField):
450
453
  radios_srcurrent=radios_srcurrent,
451
454
  flats_srcurrent=flats_srcurrent,
452
455
  )
456
+
457
+
458
+ class PCAFlatsNormalizer:
459
+ """This class implement a flatfield normalization based on a PCA of a series of acquired flatfields.
460
+ The PCA decomposition is handled by a PCAFlatsDecomposer object.
461
+
462
+ This implementation was proposed by Jailin C. et al in https://journals.iucr.org/s/issues/2017/01/00/fv5055/
463
+
464
+ Code initially written by ID11 @ ESRF staff.
465
+ Jonathan Wright - Implementation based on research paper
466
+ Pedro D. Resende - Added saving and loading from file capabilities
467
+
468
+ Jerome Lesaint - Integrated the solution in Nabu.
469
+
470
+ """
471
+
472
+ def __init__(self, components, dark, mean):
473
+ """Initializes all variables needed to perform the flatfield normalization.
474
+
475
+ Parameters
476
+ -----------
477
+ components: ndarray
478
+ The components of the PCA decomposition.
479
+ dark: ndarray
480
+ The dark image. Should be one single 2D image.
481
+ mean: ndarray
482
+ The mean image of the series of flats.
483
+ """
484
+ ones = np.ones_like(components[0], dtype=np.float32)[np.newaxis] # This comp will account for I0
485
+ self.components = np.concatenate([ones, components], axis=0)
486
+ self.dark = dark
487
+ self.mean = mean
488
+ self.n_threads = max(1, get_num_threads() // 2)
489
+ self._setmask()
490
+ self.ccdfilter = CCDFilter(mean.shape)
491
+
492
+ def _form_lsq_matrix(self):
493
+ """This function form the Least Square matrix, based on the flats components and the mask."""
494
+ self.Amat = np.stack([gg for gg in self.g], axis=1) # JL: this is the matrix for the fit
495
+
496
+ def _setmask(self, prop=0.125):
497
+ """Sets the mask to select where the model is going to be fitted.
498
+
499
+ Parameters
500
+ ----------
501
+ prop: float, default: 0.125
502
+ The proportion of the image width to take on each side of the image as a mask.
503
+
504
+ By default it sets the strips on each side of the frame in the form:
505
+ mask[:, lim:] = True
506
+ mask[:, -lim:] = True
507
+ Where lim = prop * flat.shape[1]
508
+
509
+ If you need a custom mask, see update_mask() method.
510
+ """
511
+
512
+ lim = int(prop * self.mean.shape[1])
513
+ self.mask = np.zeros(self.mean.shape, dtype=bool)
514
+ self.mask[:, :lim] = True
515
+ self.mask[:, -lim:] = True
516
+ self.g = []
517
+ for component in self.components:
518
+ self.g.append(component[self.mask])
519
+
520
+ self._form_lsq_matrix()
521
+
522
+ def update_mask(self, mask: np.ndarray):
523
+ """Method to update the mask with a custom mask in the form of a boolean 2+D array.
524
+
525
+ Paramters
526
+ =========
527
+ mask: np.ndarray of Boolean;
528
+ The array of boolean allows the selection of the region of the image that will be used to fit against the components.
529
+
530
+ It will set the mask, replacing the standard mask created with setmask().
531
+ """
532
+ if mask.dtype == bool:
533
+ self.mask = mask
534
+ self.g = []
535
+ for component in self.components:
536
+ self.g.append(component[self.mask])
537
+
538
+ self._form_lsq_matrix()
539
+ else:
540
+ raise TypeError("Not a boolean array. Will keep the default mask")
541
+
542
+ def normalize_radios(self, projections, mask=None, prop=0.125):
543
+ """This is to keep the flatfield API in the pipeline."""
544
+ self.correct_stack(projections, mask=mask, prop=prop)
545
+
546
+ def correct_stack(
547
+ self,
548
+ projections: np.ndarray,
549
+ mask: np.ndarray = None,
550
+ prop: float = 0.125,
551
+ ):
552
+ """This functions normalizes the stack of projections.
553
+
554
+ Performs correction on a stack of projections based on the calculated decomposition.
555
+ The normalizations is done in-place. The previous projections before normalization are lost.
556
+
557
+ Parameters
558
+ ----------
559
+ projections: ndarray
560
+ Stack of projections to normalize.
561
+
562
+ prop: float (default: {0.125})
563
+ Fraction to mask on the horizontal field of view, assuming vertical rotation axis
564
+
565
+ mask: np.ndarray (default: None)
566
+ Custom mask if your data requires it.
567
+
568
+ Returns
569
+ -------
570
+ corrected projections: np.ndarray
571
+ Flat field corrected images. Note that the returned projections are exp-transformed to fit the pipeline (e.g. to allow for phase retrieval).
572
+ """
573
+ self.projections = self.ccdfilter.dezinger_correction(projections, self.dark)
574
+
575
+ if mask is not None:
576
+ self.update_mask(mask=mask)
577
+ else:
578
+ self._setmask(prop=prop)
579
+
580
+ with ThreadPool(self.n_threads) as tp:
581
+ for i, cor, sol in tp.map(self._readcorrect1, range(len(projections))):
582
+ # solution[i] = sol
583
+ projections[i] = np.exp(-cor)
584
+
585
+ def _readcorrect1(self, ii):
586
+ """Method to allow parallelization of the normalization."""
587
+ corr, s = self.correctproj(self.projections[ii])
588
+ return ii, corr, s
589
+
590
+ def correctproj(self, projection):
591
+ """Performs the correction on one projection of the stack.
592
+
593
+ Parameters
594
+ ----------
595
+ projection: np.ndarray, float
596
+ Radiograph from the acquisition stack.
597
+
598
+ Returns
599
+ -------
600
+ The fitted projection.
601
+ """
602
+ logp = np.log(projection.astype(np.float32) - self.dark)
603
+ corr = self.mean - logp
604
+ # model to be fitted !!
605
+ return self.fit(corr)
606
+
607
+ def fit(self, corr):
608
+ """Fit the (masked) projection to the (masked) components of the PCA decomposition.
609
+
610
+ This is for each projection, so worth optimising ...
611
+ """
612
+ y = corr[self.mask]
613
+ solution = np.linalg.lstsq(self.Amat, y, rcond=None)[0]
614
+ correction = np.einsum("ijk,i->jk", self.components, solution)
615
+ return corr - correction, solution
616
+
617
+
618
+ class PCAFlatsDecomposer:
619
+ """This class implements a PCA decomposition of a serie of acquired flatfields.
620
+ The PCA decomposition is used to normalize the projections through a PCAFLatNormalizer object.
621
+
622
+ This implementation was proposed by Jailin C. et al in https://journals.iucr.org/s/issues/2017/01/00/fv5055/
623
+
624
+ Code initially written by ID11 @ ESRF staff.
625
+ Jonathan Wright - Implementation based on research paper
626
+ Pedro D. Resende - Added saving and loading from file capabilities
627
+
628
+ Jerome Lesaint - Integrated the solution in Nabu.
629
+ """
630
+
631
+ def __init__(self, flats: np.ndarray, darks: np.ndarray, nsigma=3):
632
+ """
633
+
634
+ Parameters
635
+ -----------
636
+ flats: np.ndarray
637
+ A stack of darks corrected flat field images
638
+ darks: np.ndarray
639
+ An image or stack of images of the dark current images of the camera.
640
+
641
+ Does the log scaling.
642
+ Subtracts mean and does eigenvector decomposition.
643
+ """
644
+ self.n_threads = max(1, get_num_threads() // 2)
645
+ self.flats = np.empty(flats.shape, dtype=np.float32)
646
+
647
+ darks = darks.astype(np.float32)
648
+ if darks.ndim == 3:
649
+ self.dark = np.median(darks, axis=0)
650
+ else:
651
+ self.dark = darks.copy()
652
+ del darks
653
+
654
+ self.nsigma = nsigma
655
+ self.ccdfilter = CCDFilter(self.dark.shape)
656
+ self._ccdfilter_and_log(flats) # Log is taken here (after dezinger)
657
+
658
+ self.mean = np.mean(self.flats, axis=0) # average
659
+
660
+ self.flats = self.flats - self.mean # makes a copy
661
+ self.cov = self.compute_correlation_matrix()
662
+ self.compute_pca()
663
+ self.generate_pca_flats(nsigma=nsigma) # Default nsigma=3
664
+
665
+ def __str__(self):
666
+
667
+ return f"PCA decomposition from flat images. \nThere are {self.components} components created at {self.sigma} level."
668
+
669
+ def _ccdfilter_and_log(self, flats: np.ndarray):
670
+ """Dezinger (substract dark, apply median filter) and takes log of the flat stack"""
671
+
672
+ self.ccdfilter.dezinger_correction(flats, self.dark)
673
+
674
+ with ThreadPool(self.n_threads) as tp:
675
+ for i, frame in enumerate(tp.map(np.log, flats)):
676
+ self.flats[i] = frame
677
+
678
+ @staticmethod
679
+ def ccij(args):
680
+ """Compute the covariance (img[i]*img[j]).sum() / npixels
681
+ It is a wrapper for threading.
682
+ args == i, j, npixels, imgs
683
+ """
684
+
685
+ i, j, NY, imgs = args
686
+ return i, j, np.einsum("ij,ij", imgs[i], imgs[j]) / NY
687
+
688
+ def compute_correlation_matrix(self):
689
+ """Computes an (nflats x nflats) correlation matrix"""
690
+
691
+ N = len(self.flats)
692
+ CC = np.zeros((N, N), float)
693
+ args = [(i, j, N, self.flats) for i in range(N) for j in range(i + 1)]
694
+ with ThreadPool(self.n_threads) as tp:
695
+ for i, j, result in tp.map(self.ccij, args):
696
+ CC[i, j] = CC[j, i] = result
697
+ return CC
698
+
699
+ def compute_pca(self):
700
+ """Gets eigenvectors and eigenvalues and sorts them into order"""
701
+
702
+ self.eigenvalues, self.eigenvectors = np.linalg.eigh(
703
+ self.cov
704
+ ) # Not sure why the eigh is needed. eig should be enough.
705
+ order = np.argsort(abs(self.eigenvalues))[::-1] # high to low
706
+ self.eigenvalues = self.eigenvalues[order]
707
+ self.eigenvectors = self.eigenvectors[:, order]
708
+
709
+ def generate_pca_flats(self, nsigma=3):
710
+ """Projects the eigenvectors back into image space.
711
+
712
+ Parameters
713
+ ----------
714
+ nsigma: int (default: 3)
715
+ """
716
+
717
+ self.sigma = nsigma
718
+ av = abs(self.eigenvalues)
719
+ N = (av > (av[-2] * nsigma)).sum() # Go for 3 sigma
720
+ self.components = [
721
+ None,
722
+ ] * N
723
+
724
+ def calculate(ii):
725
+ calc = np.einsum("i,ijk->jk", self.eigenvectors[:, ii], self.flats)
726
+ norm = (calc**2).sum()
727
+ return ii, calc / np.sqrt(norm)
728
+
729
+ with ThreadPool(self.n_threads) as tp:
730
+ for ii, result in tp.map(calculate, range(N)):
731
+ self.components[ii] = result
732
+
733
+ # simple gradients
734
+ r, c = self.components[0].shape
735
+ self.components.append(np.outer(np.ones(r), np.linspace(-1 / c, 1 / c, c)))
736
+ self.components.append(np.outer(np.linspace(-1 / r, 1 / r, r), np.ones(c)))
737
+
738
+ def save_decomposition(self, path="PCA_flats.h5", overwrite=True, entry="entry0000"):
739
+ """Saves the basic information of a PCA decomposition in view of the normalization of projections.
740
+
741
+ Parameters
742
+ ----------
743
+ path: str (default: "PCA_flats.h5")
744
+ Full path to the h5 file you want to save your results. It will overwrite!! Be careful.
745
+ """
746
+
747
+ file_exists = os.path.exists(path)
748
+ if overwrite or not file_exists:
749
+ with HDF5File(path, "w") as hout:
750
+ group = hout.create_group(entry)
751
+ group["eigenvalues"] = self.eigenvalues
752
+ group["dark"] = self.dark.astype(np.float32)
753
+ group["p_mean"] = self.mean.astype(np.float32)
754
+ group["p_components"] = np.array(self.components)
755
+ hout.flush()
756
+ else:
757
+ raise OSError(f"The file {path} already exists and you chose to NOT overwrite.")
@@ -4,7 +4,6 @@ from nabu.cuda.processing import CudaProcessing
4
4
  from ..preproc.flatfield import FlatFieldArrays
5
5
  from ..utils import deprecated_class, get_cuda_srcfile
6
6
  from ..io.reader import load_images_from_dataurl_dict
7
- from ..cuda.utils import __has_pycuda__
8
7
 
9
8
 
10
9
  class CudaFlatFieldArrays(FlatFieldArrays):
@@ -90,7 +89,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
90
89
  Radios chunk.
91
90
  """
92
91
  if not (isinstance(radios, self.cuda_processing.array_class)):
93
- raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
92
+ raise TypeError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
94
93
  if radios.dtype != np.float32:
95
94
  raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype))
96
95
  if radios.shape != self.radios_shape:
@@ -1,5 +1,5 @@
1
1
  import numpy as np
2
- from .flatfield import FlatFieldArrays, load_images_from_dataurl_dict, check_supported
2
+ from .flatfield import FlatFieldArrays, load_images_from_dataurl_dict
3
3
 
4
4
 
5
5
  class FlatFieldArraysVariableRegion(FlatFieldArrays):
@@ -27,9 +27,9 @@ class FlatFieldArraysVariableRegion(FlatFieldArrays):
27
27
  does not correspond to the length of sub_indexes which is {len(sub_indexes)}
28
28
  """
29
29
  raise ValueError(message)
30
- do_flats_distortion_correction = self.distortion_correction is not None
30
+ # do_flats_distortion_correction = self.distortion_correction is not None
31
31
 
32
- whole_dark = self.get_dark()
32
+ # whole_dark = self.get_dark()
33
33
  for i, (idx, sub_r) in enumerate(zip(sub_indexes, sub_regions_per_radio)):
34
34
  start_x, end_x, start_y, end_y = sub_r
35
35
  slice_x = slice(start_x, end_x)
nabu/preproc/phase.py CHANGED
@@ -5,13 +5,12 @@ from scipy.fft import rfft2, irfft2, fft2, ifft2
5
5
  from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning
6
6
 
7
7
  # COMPAT.
8
- from .ctf import CTFPhaseRetrieval
9
8
 
10
9
  #
11
10
 
12
11
 
13
12
  def lmicron_to_db(Lmicron, energy, distance):
14
- """
13
+ r"""
15
14
  Utility to convert the "Lmicron" parameter of PyHST
16
15
  to a value of delta/beta.
17
16
 
@@ -55,7 +54,7 @@ class PaganinPhaseRetrieval:
55
54
  fftw_num_threads=None,
56
55
  fft_num_threads=None,
57
56
  ):
58
- """
57
+ r"""
59
58
  Paganin Phase Retrieval for an infinitely distant point source.
60
59
  Formula (10) in [1].
61
60
 
@@ -98,7 +97,6 @@ class PaganinPhaseRetrieval:
98
97
 
99
98
  Notes
100
99
  ------
101
-
102
100
  **Padding methods**
103
101
 
104
102
  The phase retrieval is a convolution done in Fourier domain using FFT,
@@ -107,13 +107,13 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
107
107
  assert data.dtype == np.float32
108
108
  # Rectangular memcopy
109
109
  # TODO profile, and if needed include this copy in the padding kernel
110
- if isinstance(data, np.ndarray) or isinstance(data, self.cuda_processing.array_class):
110
+ if isinstance(data, np.ndarray) or isinstance(data, self.cuda_processing.array_class): # noqa: SIM101
111
111
  self.d_radio_padded[: self.shape[0], : self.shape[1]] = data[:, :]
112
112
  elif isinstance(data, cuda.DeviceAllocation):
113
113
  # TODO manual memcpy2D
114
114
  raise NotImplementedError("pycuda buffers are not supported yet")
115
115
  else:
116
- raise ValueError("Expected either numpy array, pycuda array or pycuda buffer")
116
+ raise TypeError("Expected either numpy array, pycuda array or pycuda buffer")
117
117
 
118
118
  def get_output(self, output):
119
119
  s0, s1 = self.shape
nabu/preproc/shift.py CHANGED
@@ -42,7 +42,7 @@ class VerticalShift:
42
42
  def _init_interp_coefficients(self):
43
43
  self.interp_infos = []
44
44
  for s in self.shifts:
45
- s0 = int(floor(s))
45
+ s0 = floor(s)
46
46
  f = s - s0
47
47
  self.interp_infos.append([s0, f])
48
48
 
@@ -51,7 +51,7 @@ class VerticalShift:
51
51
  assert np.max(iangles) < len(self.interp_infos)
52
52
  assert len(iangles) == radios.shape[0]
53
53
 
54
- def apply_vertical_shifts(self, radios, iangles, output=None):
54
+ def apply_vertical_shifts(self, radios, iangles=None, output=None):
55
55
  """
56
56
  Parameters
57
57
  ----------
@@ -65,6 +65,8 @@ class VerticalShift:
65
65
  If given, it will be modified to contain the shifted radios.
66
66
  Must be of the same shape of `radios`.
67
67
  """
68
+ if iangles is None:
69
+ iangles = np.arange(radios.shape[0])
68
70
  self._check(radios, iangles)
69
71
 
70
72
  newradio = np.zeros_like(radios[0])
@@ -1,5 +1,4 @@
1
1
  import numpy as np
2
- from ..cuda.utils import __has_pycuda__
3
2
  from ..cuda.processing import CudaProcessing
4
3
  from ..processing.muladd_cuda import CudaMulAdd
5
4
  from .shift import VerticalShift
@@ -10,7 +10,7 @@ from nabu.preproc import ctf
10
10
  from nabu.estimation.distortion import estimate_flat_distortion
11
11
  from nabu.misc.filters import correct_spikes
12
12
  from nabu.preproc.distortion import DistortionCorrection
13
- from nabu.cuda.utils import __has_pycuda__, get_cuda_context
13
+ from nabu.cuda.utils import __has_pycuda__
14
14
 
15
15
  __has_cufft__ = False
16
16
  if __has_pycuda__:
@@ -220,10 +220,10 @@ class TestCtf:
220
220
  # Test multi-core FFT
221
221
  ctf_fft = ctf.CtfFilter(*ctf_args, **ctf_kwargs, use_rfft=True, fft_num_threads=0)
222
222
  if ctf_fft.use_rfft:
223
- phase_fft = ctf_fft.retrieve_phase(img)
223
+ # phase_fft = ctf_fft.retrieve_phase(img)
224
224
  self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFT")
225
225
 
226
- @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (scikit-cuda or vkfft)")
226
+ @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (cupy? or vkfft)")
227
227
  def test_cuda_ctf(self):
228
228
  data = nabu_get_data("brain_phantom.npz")["data"]
229
229
  delta_beta = 50.0
@@ -232,7 +232,7 @@ class TestCtf:
232
232
  pix_size_m = 0.1e-6
233
233
 
234
234
  geo_pars = ctf.GeoPars(z2=distance_m, pix_size_det=pix_size_m, wavelength=1.23984199e-9 / energy_kev)
235
- ctx = get_cuda_context()
235
+ # ctx = get_cuda_context()
236
236
 
237
237
  for normalize in [True, False]:
238
238
  ctf_filter = ctf.CTFPhaseRetrieval(
@@ -1,4 +1,4 @@
1
- import os.path as path
1
+ from os import path
2
2
  from math import exp
3
3
  import tempfile
4
4
  import numpy as np
@@ -411,7 +411,7 @@ class FlatFieldTestDataset:
411
411
  self._generate_projections()
412
412
 
413
413
  def get_flat_idx(self, proj_idx):
414
- flats_idx = sorted(list(self.flats.keys()))
414
+ flats_idx = sorted(self.flats.keys())
415
415
  if proj_idx <= flats_idx[0]:
416
416
  return (flats_idx[0],)
417
417
  elif proj_idx > flats_idx[0] and proj_idx < flats_idx[1]:
@@ -77,9 +77,7 @@ class TestPaganin:
77
77
  errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy))
78
78
  assert errmax < self.rtol_pag, "Max error is too high"
79
79
 
80
- @pytest.mark.skipif(
81
- not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
82
- )
80
+ @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (cupy? or vkfft) for this test")
83
81
  @pytest.mark.parametrize("config", scenarios)
84
82
  def test_gpu_paganin(self, config):
85
83
  paganin, data, pag_kwargs = self.get_paganin_instance_and_data(config, self.data)