nabu 2022.3.0a1__py3-none-any.whl → 2023.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +7 -1
  3. nabu/app/cast_volume.py +8 -2
  4. nabu/app/cli_configs.py +69 -0
  5. nabu/app/composite_cor.py +97 -0
  6. nabu/app/create_distortion_map_from_poly.py +118 -0
  7. nabu/app/nx_z_splitter.py +1 -1
  8. nabu/app/prepare_weights_double.py +21 -16
  9. nabu/app/reconstruct_helical.py +0 -1
  10. nabu/app/utils.py +10 -5
  11. nabu/cuda/processing.py +1 -0
  12. nabu/cuda/tests/test_padding.py +1 -0
  13. nabu/cuda/utils.py +1 -0
  14. nabu/distributed/__init__.py +0 -0
  15. nabu/distributed/utils.py +57 -0
  16. nabu/distributed/worker.py +543 -0
  17. nabu/estimation/cor.py +3 -7
  18. nabu/estimation/cor_sino.py +2 -1
  19. nabu/estimation/distortion.py +6 -4
  20. nabu/io/cast_volume.py +10 -1
  21. nabu/io/detector_distortion.py +305 -0
  22. nabu/io/reader.py +37 -7
  23. nabu/io/reader_helical.py +0 -3
  24. nabu/io/tests/test_cast_volume.py +16 -4
  25. nabu/io/tests/test_detector_distortion.py +178 -0
  26. nabu/io/tests/test_writers.py +2 -2
  27. nabu/io/tiffwriter_zmm.py +2 -3
  28. nabu/io/writer.py +84 -1
  29. nabu/io/writer_BACKUP_193259.py +556 -0
  30. nabu/io/writer_BACKUP_193381.py +556 -0
  31. nabu/io/writer_BASE_193259.py +548 -0
  32. nabu/io/writer_BASE_193381.py +548 -0
  33. nabu/io/writer_LOCAL_193259.py +550 -0
  34. nabu/io/writer_LOCAL_193381.py +550 -0
  35. nabu/io/writer_REMOTE_193259.py +557 -0
  36. nabu/io/writer_REMOTE_193381.py +557 -0
  37. nabu/misc/fourier_filters.py +2 -0
  38. nabu/misc/rotation.py +0 -1
  39. nabu/misc/tests/test_rotation.py +1 -0
  40. nabu/pipeline/config_validators.py +10 -0
  41. nabu/pipeline/datadump.py +1 -1
  42. nabu/pipeline/dataset_validator.py +0 -1
  43. nabu/pipeline/detector_distortion_provider.py +20 -0
  44. nabu/pipeline/estimators.py +35 -21
  45. nabu/pipeline/fallback_utils.py +1 -1
  46. nabu/pipeline/fullfield/chunked.py +30 -15
  47. nabu/pipeline/fullfield/chunked_black.py +881 -0
  48. nabu/pipeline/fullfield/chunked_cuda.py +34 -4
  49. nabu/pipeline/fullfield/chunked_fb.py +966 -0
  50. nabu/pipeline/fullfield/chunked_google.py +921 -0
  51. nabu/pipeline/fullfield/chunked_pep8.py +920 -0
  52. nabu/pipeline/fullfield/computations.py +7 -6
  53. nabu/pipeline/fullfield/dataset_validator.py +1 -1
  54. nabu/pipeline/fullfield/grouped_cuda.py +6 -0
  55. nabu/pipeline/fullfield/nabu_config.py +15 -3
  56. nabu/pipeline/fullfield/processconfig.py +5 -0
  57. nabu/pipeline/fullfield/reconstruction.py +1 -2
  58. nabu/pipeline/helical/gridded_accumulator.py +1 -8
  59. nabu/pipeline/helical/helical_chunked_regridded.py +48 -33
  60. nabu/pipeline/helical/helical_reconstruction.py +1 -9
  61. nabu/pipeline/helical/nabu_config.py +11 -14
  62. nabu/pipeline/helical/span_strategy.py +11 -4
  63. nabu/pipeline/helical/tests/test_accumulator.py +0 -3
  64. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -6
  65. nabu/pipeline/helical/tests/test_strategy.py +0 -1
  66. nabu/pipeline/helical/weight_balancer.py +0 -1
  67. nabu/pipeline/params.py +4 -0
  68. nabu/pipeline/processconfig.py +6 -2
  69. nabu/pipeline/writer.py +9 -4
  70. nabu/preproc/distortion.py +4 -3
  71. nabu/preproc/double_flatfield.py +16 -4
  72. nabu/preproc/double_flatfield_cuda.py +3 -2
  73. nabu/preproc/double_flatfield_variable_region.py +13 -4
  74. nabu/preproc/flatfield.py +29 -7
  75. nabu/preproc/flatfield_cuda.py +0 -1
  76. nabu/preproc/flatfield_variable_region.py +5 -2
  77. nabu/preproc/phase.py +0 -1
  78. nabu/preproc/phase_cuda.py +0 -1
  79. nabu/preproc/tests/test_ctf.py +4 -3
  80. nabu/preproc/tests/test_flatfield.py +6 -7
  81. nabu/reconstruction/fbp_opencl.py +1 -1
  82. nabu/reconstruction/filtering.py +0 -1
  83. nabu/reconstruction/tests/test_fbp.py +1 -0
  84. nabu/resources/dataset_analyzer.py +0 -1
  85. nabu/resources/templates/bm05_pag.conf +34 -0
  86. nabu/resources/templates/id16_ctf.conf +2 -1
  87. nabu/resources/tests/test_nxflatfield.py +0 -1
  88. nabu/resources/tests/test_units.py +0 -1
  89. nabu/stitching/frame_composition.py +7 -1
  90. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/METADATA +2 -7
  91. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/RECORD +96 -75
  92. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/WHEEL +1 -1
  93. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/entry_points.txt +2 -1
  94. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/LICENSE +0 -0
  95. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/top_level.txt +0 -0
  96. {nabu-2022.3.0a1.dist-info → nabu-2023.1.0a2.dist-info}/zip-safe +0 -0
@@ -26,7 +26,7 @@ class SpanStrategy:
26
26
  ----------
27
27
  z_pix_per_proj : array of floats
28
28
  an array of floats with one entry per projection, in pixel units. The values are the vertical displacements of the detector.
29
- An increasing z means that the rotation axis is following the positive direction of the detector vertical axis, which is pointing toward the ground.
29
+ An decreasing z means that the rotation axis is following the positive direction of the detector vertical axis, which is pointing toward the ground.
30
30
  In the experimental setup, the vertical detector axis is pointing toward the ground. Moreover the values are offsetted so that the
31
31
  first value is zero. The offset value, in millimiters is z_offset_mm and it is the vertical position of the sample stage relatively
32
32
  to the center of the detector. A negative z_offset_mm means that the sample stage is below the detector for the first projection, and this is almost
@@ -186,9 +186,18 @@ class SpanStrategy:
186
186
 
187
187
  def get_informative_string(self):
188
188
  doable_span_v = self.get_doable_span()
189
+ if self.z_pix_per_proj[-1] > self.z_pix_per_proj[-1]:
190
+ direction = "ascending"
191
+ else:
192
+ direction = "descending"
193
+
189
194
  s = f"""
190
195
  Doable vertical span
191
196
  --------------------
197
+ The scan has been performed with an {direction} vertical translation of the rotation axis.
198
+
199
+ The detector vertical axis is up side down.
200
+
192
201
  Detector reference system at iproj=0:
193
202
  from vertical view height ... {doable_span_v.view_heights_minmax[0]}
194
203
  up to (included) ... {doable_span_v.view_heights_minmax[1]}
@@ -196,9 +205,7 @@ class SpanStrategy:
196
205
  The slice that projects to the first line of the first projection
197
206
  corresponds to vertical heigth = 0
198
207
 
199
- In the sample stage reference system:
200
- from vertical height above stage ( pixel units) ... {doable_span_v.z_pix_minmax[0]}
201
- up to (included) ... {doable_span_v.z_pix_minmax[1]}
208
+ In voxels, the vertical doable span measures: {doable_span_v.z_pix_minmax[1] - doable_span_v.z_pix_minmax[0]}
202
209
 
203
210
  And in millimiters above the stage:
204
211
  from vertical height above stage ( mm units) ... {doable_span_v.z_mm_minmax[0] - self.z_offset_mm }
@@ -57,7 +57,6 @@ class TestGriddedAccumulator:
57
57
  """
58
58
 
59
59
  def test_regridding(self):
60
-
61
60
  span_info = span_strategy.SpanStrategy(
62
61
  z_pix_per_proj=self.z_pix_per_proj,
63
62
  x_pix_per_proj=self.x_pix_per_proj,
@@ -109,7 +108,6 @@ class TestGriddedAccumulator:
109
108
  pnum_end_list = pnum_start_list[1:] + [proj_num_end]
110
109
 
111
110
  for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list):
112
-
113
111
  start_in_chunk = pnum_start - my_first_pnum
114
112
  end_in_chunk = pnum_end - my_first_pnum
115
113
 
@@ -127,7 +125,6 @@ class TestGriddedAccumulator:
127
125
  # h5py.File("processed_sinogram.h5","w")["sinogram"] = res
128
126
 
129
127
  def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, sub_region, span_info):
130
-
131
128
  my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice]
132
129
  fract_complement_shifts_v = chunk_info.fract_complement_to_integer_shift_v[subchunk_slice]
133
130
  x_shifts_list = chunk_info.x_pix_per_proj[subchunk_slice]
@@ -98,7 +98,6 @@ class TestGriddedAccumulator:
98
98
  """
99
99
 
100
100
  def test_regridding(self):
101
-
102
101
  span_info = span_strategy.SpanStrategy(
103
102
  z_pix_per_proj=self.z_pix_per_proj,
104
103
  x_pix_per_proj=self.x_pix_per_proj,
@@ -150,7 +149,6 @@ class TestGriddedAccumulator:
150
149
  pnum_end_list = pnum_start_list[1:] + [proj_num_end]
151
150
 
152
151
  for pnum_start, pnum_end in zip(pnum_start_list, pnum_end_list):
153
-
154
152
  start_in_chunk = pnum_start - my_first_pnum
155
153
  end_in_chunk = pnum_end - my_first_pnum
156
154
 
@@ -205,7 +203,6 @@ class TestGriddedAccumulator:
205
203
  # put the test here
206
204
 
207
205
  def _reconstruct(self):
208
-
209
206
  axis_corrections = np.zeros_like(self.reconstruction_space.gridded_angles_rad)
210
207
  self.reconstruction.set_custom_angles_and_axis_corrections(
211
208
  self.reconstruction_space.gridded_angles_rad, axis_corrections
@@ -222,7 +219,6 @@ class TestGriddedAccumulator:
222
219
  n_provided_angles = self.d_radios_slim.shape[0]
223
220
 
224
221
  for first_angle_index in range(0, n_provided_angles, self.num_weight_radios_per_app):
225
-
226
222
  end_angle_index = min(n_provided_angles, first_angle_index + self.num_weight_radios_per_app)
227
223
  self._d_radios_weights[: end_angle_index - first_angle_index].set(
228
224
  weights[first_angle_index:end_angle_index, i_slice]
@@ -242,7 +238,6 @@ class TestGriddedAccumulator:
242
238
  )
243
239
 
244
240
  def _init_reconstructor(self, processed_radios_shape):
245
-
246
241
  one_slice_data_shape = processed_radios_shape[:1] + processed_radios_shape[2:]
247
242
 
248
243
  self.d_radios_slim = garray.zeros(one_slice_data_shape, np.float32)
@@ -297,7 +292,6 @@ class TestGriddedAccumulator:
297
292
  return processed_radios
298
293
 
299
294
  def _read_data_and_apply_flats(self, sub_total_prange_slice, subchunk_slice, chunk_info, sub_region, span_info):
300
-
301
295
  my_integer_shifts_v = chunk_info.integer_shift_v[subchunk_slice]
302
296
 
303
297
  subr_start_z, subr_end_z = sub_region
@@ -34,7 +34,6 @@ def bootstrap_TestStrategy(request):
34
34
  @pytest.mark.usefixtures("bootstrap_TestStrategy")
35
35
  class TestStrategy:
36
36
  def test_strategy(self):
37
-
38
37
  # the python implementation is slow. so we take only a p[art of the scan
39
38
  limit = 4000
40
39
  span_info = SpanStrategy(
@@ -43,7 +43,6 @@ class WeightBalancer:
43
43
  angle = self.my_angles_rad[i]
44
44
 
45
45
  for i_half_turn in range(-n_span - 1, n_span + 2):
46
-
47
46
  if i_half_turn == 0:
48
47
  w_res[:] += radios_weights[i]
49
48
  continue
nabu/pipeline/params.py CHANGED
@@ -78,6 +78,7 @@ files_formats = {
78
78
  "j2k": "jp2",
79
79
  "jpeg2000": "jp2",
80
80
  "edf": "edf",
81
+ "vol": "vol",
81
82
  }
82
83
 
83
84
  distribution_methods = {
@@ -129,6 +130,9 @@ rings_methods = {
129
130
  "munch": "munch",
130
131
  }
131
132
 
133
+ detector_distortion_correction_methods = {"none": None, "": None, "identity": "identity", "map_xz": "map_xz"}
134
+
135
+
132
136
  radios_rotation_mode = {
133
137
  "none": None,
134
138
  "": None,
@@ -1,6 +1,6 @@
1
1
  import os
2
2
  from .config import parse_nabu_config_file
3
- from ..utils import deprecation_warning
3
+ from ..utils import deprecation_warning, is_writeable
4
4
  from ..resources.logger import Logger, PrinterLogger
5
5
  from .config import validate_config
6
6
  from ..resources.dataset_analyzer import analyze_dataset, _tomoscan_has_nxversion
@@ -92,7 +92,11 @@ class ProcessConfigBase:
92
92
  logger_filename = create_logger
93
93
  else:
94
94
  raise ValueError("Expected bool or str for create_logger")
95
- self.logger = Logger("nabu", level=self.nabu_config["pipeline"]["verbosity"], logfile=logger_filename)
95
+ if not is_writeable(os.path.dirname(logger_filename)):
96
+ self.logger = PrinterLogger()
97
+ self.logger.error("Cannot create logger file %s: no permission to write therein" % logger_filename)
98
+ else:
99
+ self.logger = Logger("nabu", level=self.nabu_config["pipeline"]["verbosity"], logfile=logger_filename)
96
100
 
97
101
  def _parse_configuration(self, conf_fname, conf_dict):
98
102
  """
nabu/pipeline/writer.py CHANGED
@@ -4,7 +4,7 @@ from posixpath import join as posixjoin
4
4
  from silx.io.dictdump import dicttonx
5
5
  from tomoscan.esrf import HDF5Volume, TIFFVolume, MultiTIFFVolume, EDFVolume, JP2KVolume
6
6
  from ..resources.logger import LoggerOrPrint
7
- from ..io.writer import get_datetime, NXProcessWriter
7
+ from ..io.writer import get_datetime, NXProcessWriter, HSTVolVolume
8
8
  from ..io.utils import convert_dict_values
9
9
  from .. import version as nabu_version
10
10
  from ..resources.utils import is_hdf5_extension
@@ -94,6 +94,7 @@ class WriterManager:
94
94
  self.is_bigtiff = file_format in ["tiff", "tif"] and any(
95
95
  [self.extra_options.get(opt, False) for opt in ["tiff_single_file", "use_bigtiff"]]
96
96
  )
97
+ self.is_vol = file_format == "vol"
97
98
 
98
99
  self.file_prefix = file_prefix
99
100
  self._set_output_dir(output_dir)
@@ -127,7 +128,7 @@ class WriterManager:
127
128
  if self.is_bigtiff:
128
129
  writer = MultiTIFFVolume
129
130
  writer_kwargs = _get_writer_kwargs_multi_frames()
130
- writer_kwargs.update({"append": self.extra_options.get("single_tiff_initialized", False)})
131
+ writer_kwargs.update({"append": self.extra_options.get("single_output_file_initialized", False)})
131
132
  else:
132
133
  writer = TIFFVolume
133
134
  writer_kwargs = _get_writer_kwargs_single_frame()
@@ -138,10 +139,14 @@ class WriterManager:
138
139
  "overwrite": True,
139
140
  }
140
141
  )
142
+ elif file_format == "vol":
143
+ writer = HSTVolVolume
144
+ writer_kwargs = _get_writer_kwargs_multi_frames()
145
+ writer_kwargs.update({"append": self.extra_options.get("single_output_file_initialized", False)})
141
146
  elif file_format == "edf":
142
147
  writer = EDFVolume
143
148
  writer_kwargs = _get_writer_kwargs_single_frame()
144
- elif file_format in ["jp2k", "j2k", "jp2000", "jpeg2000"]:
149
+ elif file_format in ["jp2k", "j2k", "jp2", "jp2000", "jpeg2000"]:
145
150
  writer = JP2KVolume
146
151
  writer_kwargs = _get_writer_kwargs_single_frame()
147
152
  else:
@@ -153,7 +158,7 @@ class WriterManager:
153
158
  # This class is generally used to create partial files, i.e files containing a subset of the processed volume.
154
159
  # In this case, the files containing partial results are stored in a sub-directory with the same file prefix.
155
160
  # Otherwise, everything is put in a single file (for now it's only the case for "big tiff").
156
- self.is_partial_file = not (self.is_bigtiff)
161
+ self.is_partial_file = not (self.is_bigtiff or self.is_vol)
157
162
  if self.is_partial_file:
158
163
  output_dir = path.join(output_dir, self.file_prefix)
159
164
 
@@ -1,5 +1,5 @@
1
1
  import numpy as np
2
- from scipy.interpolate import interpn
2
+ from scipy.interpolate import RegularGridInterpolator
3
3
  from ..utils import check_supported
4
4
  from ..estimation.distortion import estimate_flat_distortion
5
5
 
@@ -15,13 +15,14 @@ def correct_distortion_interpn(image, coords, bounds_error=False, fill_value=Non
15
15
  coords: array
16
16
  Coordinates of the distortion correction to apply, with the shape (Ny, Nx, 2)
17
17
  """
18
- return interpn(
18
+ foo = RegularGridInterpolator(
19
19
  (np.arange(image.shape[0]), np.arange(image.shape[1])),
20
20
  image,
21
- coords,
22
21
  bounds_error=bounds_error,
22
+ method="linear",
23
23
  fill_value=fill_value,
24
24
  )
25
+ return foo(coords)
25
26
 
26
27
 
27
28
  class DistortionCorrection:
@@ -8,7 +8,6 @@ from ..io.writer import Writers
8
8
 
9
9
 
10
10
  class DoubleFlatField:
11
-
12
11
  _default_h5_path = "/entry/double_flatfield/results"
13
12
  _small = 1e-7
14
13
 
@@ -17,6 +16,7 @@ class DoubleFlatField:
17
16
  shape,
18
17
  result_url=None,
19
18
  sub_region=None,
19
+ detector_corrector=None,
20
20
  input_is_mlog=True,
21
21
  output_is_mlog=False,
22
22
  average_is_on_log=False,
@@ -55,7 +55,7 @@ class DoubleFlatField:
55
55
  self.radios_shape = get_2D_3D_shape(shape)
56
56
  self.n_angles = self.radios_shape[0]
57
57
  self.shape = self.radios_shape[1:]
58
- self._init_filedump(result_url, sub_region)
58
+ self._init_filedump(result_url, sub_region, detector_corrector)
59
59
  self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode)
60
60
  self._computed = False
61
61
 
@@ -67,6 +67,13 @@ class DoubleFlatField:
67
67
 
68
68
  def _load_dff_dump(self):
69
69
  res = self.reader.get_data(self.result_url)
70
+ if self.detector_corrector is not None:
71
+ if res.ndim == 2:
72
+ res = self.detector_corrector.transform(res)
73
+ else:
74
+ for i in range(res.shape[0]):
75
+ res[i] = self.detector_corrector.transform(res[i])
76
+
70
77
  if res.ndim == 3 and res.shape[0] == 1:
71
78
  res = res.reshape(res.shape[1], res.shape[2])
72
79
  if res.shape != self.shape:
@@ -76,10 +83,11 @@ class DoubleFlatField:
76
83
  )
77
84
  return res
78
85
 
79
- def _init_filedump(self, result_url, sub_region):
86
+ def _init_filedump(self, result_url, sub_region, detector_corrector=None):
80
87
  if isinstance(result_url, str):
81
88
  result_url = DataUrl(file_path=result_url, data_path=self._default_h5_path)
82
89
  self.sub_region = sub_region
90
+ self.detector_corrector = detector_corrector
83
91
  self.result_url = result_url
84
92
  self.writer = None
85
93
  self.reader = None
@@ -87,7 +95,11 @@ class DoubleFlatField:
87
95
  return
88
96
  self._get_reader_writer_class()
89
97
  if path.exists(result_url.file_path()):
90
- self.reader = self._reader_cls(sub_region=self.sub_region)
98
+ if detector_corrector is None:
99
+ adapted_subregion = sub_region
100
+ else:
101
+ adapted_subregion = self.detector_corrector.get_adapted_subregion(sub_region)
102
+ self.reader = self._reader_cls(sub_region=adapted_subregion)
91
103
  else:
92
104
  self.writer = self._writer_cls(self.result_url.file_path())
93
105
 
@@ -1,7 +1,6 @@
1
- import numpy as np
2
1
  from .double_flatfield import DoubleFlatField
3
2
  from ..utils import check_shape
4
- from ..cuda.utils import get_cuda_context, __has_pycuda__
3
+ from ..cuda.utils import __has_pycuda__
5
4
  from ..cuda.processing import CudaProcessing
6
5
  from ..misc.unsharp_cuda import CudaUnsharpMask
7
6
 
@@ -16,6 +15,7 @@ class CudaDoubleFlatField(DoubleFlatField):
16
15
  shape,
17
16
  result_url=None,
18
17
  sub_region=None,
18
+ detector_corrector=None,
19
19
  input_is_mlog=True,
20
20
  output_is_mlog=False,
21
21
  average_is_on_log=False,
@@ -31,6 +31,7 @@ class CudaDoubleFlatField(DoubleFlatField):
31
31
  shape,
32
32
  result_url=result_url,
33
33
  sub_region=sub_region,
34
+ detector_corrector=detector_corrector,
34
35
  input_is_mlog=input_is_mlog,
35
36
  output_is_mlog=output_is_mlog,
36
37
  average_is_on_log=average_is_on_log,
@@ -8,7 +8,14 @@ from ..misc.binning import get_binning_function
8
8
 
9
9
 
10
10
  class DoubleFlatFieldVariableRegion(DoubleFlatField):
11
- def __init__(self, shape, result_url=None, binning_x=None, binning_z=None):
11
+ def __init__(
12
+ self,
13
+ shape,
14
+ result_url=None,
15
+ binning_x=None,
16
+ binning_z=None,
17
+ detector_corrector=None,
18
+ ):
12
19
  """This class provides the division by the double flat field.
13
20
  At variance with the standard class, it store as member the
14
21
  whole field, and performs the division by the proper region
@@ -20,12 +27,11 @@ class DoubleFlatFieldVariableRegion(DoubleFlatField):
20
27
  self.radios_shape = get_2D_3D_shape(shape)
21
28
  self.n_angles = self.radios_shape[0]
22
29
  self.shape = self.radios_shape[1:]
23
- self._init_filedump(result_url, None)
30
+ self._init_filedump(result_url, None, detector_corrector)
24
31
 
25
32
  data = self._load_dff_full_dump()
26
33
 
27
34
  if (binning_z, binning_x) != (1, 1):
28
-
29
35
  print(" (binning_z, binning_x) ", (binning_z, binning_x))
30
36
  binning_function = get_binning_function((binning_z, binning_x))
31
37
  if binning_function is None:
@@ -38,6 +44,10 @@ class DoubleFlatFieldVariableRegion(DoubleFlatField):
38
44
 
39
45
  def _load_dff_full_dump(self):
40
46
  res = self.reader.get_data(self.result_url)
47
+ if self.detector_corrector is not None:
48
+ self.detector_corrector.set_full_transformation()
49
+ res = self.detector_corrector.transform(res, do_full=True)
50
+
41
51
  return res
42
52
 
43
53
  def apply_double_flatfield_for_sub_regions(self, radios, sub_regions_per_radio):
@@ -48,7 +58,6 @@ class DoubleFlatFieldVariableRegion(DoubleFlatField):
48
58
  my_double_ff = self.data
49
59
 
50
60
  for i in range(radios.shape[0]):
51
-
52
61
  s_x, e_x, s_y, e_y = sub_regions_per_radio[i]
53
62
 
54
63
  dff = my_double_ff[s_y:e_y, s_x:e_x]
nabu/preproc/flatfield.py CHANGED
@@ -1,7 +1,8 @@
1
+ from multiprocessing.pool import ThreadPool
1
2
  from bisect import bisect_left
2
3
  import numpy as np
3
4
  from ..io.reader import load_images_from_dataurl_dict
4
- from ..utils import check_supported
5
+ from ..utils import check_supported, get_num_threads
5
6
 
6
7
 
7
8
  class FlatFieldArrays:
@@ -26,6 +27,7 @@ class FlatFieldArrays:
26
27
  nan_value=1.0,
27
28
  radios_srcurrent=None,
28
29
  flats_srcurrent=None,
30
+ n_threads=None,
29
31
  ):
30
32
  """
31
33
  Initialize a flat-field normalization process.
@@ -59,6 +61,8 @@ class FlatFieldArrays:
59
61
  for the corresponding flat. The items must be ordered in the same order as the flats indices (`flats.keys()`).
60
62
  This parameter must be used along with 'radios_srcurrent'.
61
63
  Please refer to "Notes" for more information on this normalization.
64
+ n_threads: int or None, optional
65
+ Number of threads to use for flat-field correction. Default is to use half the threads.
62
66
 
63
67
  Important
64
68
  ----------
@@ -98,6 +102,7 @@ class FlatFieldArrays:
98
102
  self._precompute_flats_indices_weights()
99
103
  self._configure_srcurrent_normalization(radios_srcurrent, flats_srcurrent)
100
104
  self.distortion_correction = distortion_correction
105
+ self.n_threads = min(1, get_num_threads(n_threads) // 2)
101
106
 
102
107
  def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
103
108
  self._set_radios_shape(radios_shape)
@@ -147,8 +152,7 @@ class FlatFieldArrays:
147
152
  self._sorted_dark_indices = sorted(self.darks.keys())
148
153
  self._dark = None
149
154
 
150
- @staticmethod
151
- def _check_frames(frames, frames_type, min_frames_required, max_frames_supported):
155
+ def _check_frames(self, frames, frames_type, min_frames_required, max_frames_supported):
152
156
  n_frames = len(frames)
153
157
  if n_frames < min_frames_required:
154
158
  raise ValueError("Need at least %d %s" % (min_frames_required, frames_type))
@@ -156,6 +160,15 @@ class FlatFieldArrays:
156
160
  raise ValueError(
157
161
  "Flat-fielding with more than %d %s is not supported" % (max_frames_supported, frames_type)
158
162
  )
163
+ self._check_frame_shape(frames, frames_type)
164
+
165
+ def _check_frame_shape(self, frames, frames_type):
166
+ for frame_idx, frame in frames.items():
167
+ if frame.shape != self.shape:
168
+ raise ValueError(
169
+ "Invalid shape for %s %s: expected %s, but got %s"
170
+ % (frames_type, frame_idx, str(self.shape), str(frame.shape))
171
+ )
159
172
 
160
173
  def _check_radios_and_indices_congruence(self, radios_indices):
161
174
  if radios_indices.size != self.n_radios:
@@ -318,15 +331,24 @@ class FlatFieldArrays:
318
331
  """
319
332
  do_flats_distortion_correction = self.distortion_correction is not None
320
333
  dark = self.get_dark()
321
- for i in range(self.n_radios):
334
+
335
+ def apply_flatfield(i):
322
336
  radio_data = radios[i]
323
337
  radio_data -= dark
324
338
  flat = self.get_flat(i)
325
339
  flat = flat - dark
326
340
  if do_flats_distortion_correction:
327
341
  flat = self.distortion_correction.estimate_and_correct(flat, radio_data)
328
- radios[i] = radio_data / flat
329
- self.remove_invalid_values(radios[i])
342
+ np.divide(radio_data, flat, out=radio_data)
343
+ self.remove_invalid_values(radio_data)
344
+
345
+ if self.n_threads > 2:
346
+ with ThreadPool(self.n_threads) as tp:
347
+ tp.map(apply_flatfield, range(self.n_radios))
348
+ else:
349
+ for i in range(self.n_radios):
350
+ apply_flatfield(i)
351
+
330
352
  if self.normalize_srcurrent:
331
353
  radios *= self.srcurrent_ratios[:, np.newaxis, np.newaxis]
332
354
  return radios
@@ -388,7 +410,7 @@ class FlatFieldDataUrls(FlatField):
388
410
  interpolation: str, optional
389
411
  Interpolation method for flat-field. See below for more details.
390
412
  distortion_correction: DistortionCorrection, optional
391
- A DistortionCorrection object. If provided, it is used to
413
+ A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio.
392
414
  nan_value: float, optional
393
415
  Which float value is used to replace nan/inf after flat-field.
394
416
 
@@ -133,7 +133,6 @@ class CudaFlatFieldDataUrls(CudaFlatField):
133
133
  cuda_options: Union[dict, None] = None,
134
134
  **chunk_reader_kwargs,
135
135
  ):
136
-
137
136
  flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
138
137
  darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs)
139
138
  super().__init__(
@@ -3,9 +3,13 @@ from .flatfield import FlatFieldArrays, load_images_from_dataurl_dict, check_sup
3
3
 
4
4
 
5
5
  class FlatFieldArraysVariableRegion(FlatFieldArrays):
6
-
7
6
  _full_shape = True
8
7
 
8
+ def _check_frame_shape(self, frames, frames_type):
9
+ # in helical the flat is the whole one and its shape does not necesseraly match the smaller frames.
10
+ # Therefore no check is done to allow this.
11
+ pass
12
+
9
13
  def _check_radios_and_indices_congruence(self, radios_indices):
10
14
  """At variance with parent class, preprocesing is done with on a fraction of the radios,
11
15
  whose length may vary. So we dont enforce here that the lenght is always the same
@@ -50,7 +54,6 @@ class FlatFieldDataVariableRegionUrls(FlatFieldArraysVariableRegion):
50
54
  flats_srcurrent=None,
51
55
  **chunk_reader_kwargs,
52
56
  ):
53
-
54
57
  flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
55
58
  darks_arrays_dict = load_images_from_dataurl_dict(darks, **chunk_reader_kwargs)
56
59
 
nabu/preproc/phase.py CHANGED
@@ -38,7 +38,6 @@ def lmicron_to_db(Lmicron, energy, distance):
38
38
 
39
39
 
40
40
  class PaganinPhaseRetrieval:
41
-
42
41
  available_padding_modes = ["zeros", "mean", "edge", "symmetric", "reflect"]
43
42
  powers = generate_powers()
44
43
 
@@ -9,7 +9,6 @@ from ..cuda.kernel import CudaKernel
9
9
 
10
10
 
11
11
  class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
12
-
13
12
  supported_paddings = ["zeros", "constant", "edge"]
14
13
 
15
14
  def __init__(
@@ -54,7 +54,6 @@ class TestCtf:
54
54
  diff[diff > np.percentile(diff, 99)] = 0
55
55
  assert diff.max() < self.abs_tol * (np.abs(ref).mean()), error_message
56
56
 
57
- @pytest.mark.skipif(True, reason="wait for scipy.interpolate calls to be fixed")
58
57
  def test_ctf_id16_way(self):
59
58
  """test the ctf phase retrieval.
60
59
  The cft filter, of the CtfFilter class is iniitalised with the geomety informations contained in geo_pars object
@@ -95,13 +94,15 @@ class TestCtf:
95
94
  correction_spike_threshold=3,
96
95
  )
97
96
 
98
- my_flat = scipy.interpolate.interpn(
97
+ interpolator = scipy.interpolate.RegularGridInterpolator(
99
98
  (np.arange(my_flat.shape[0]), np.arange(my_flat.shape[1])),
100
99
  my_flat,
101
- new_coordinates,
102
100
  bounds_error=False,
101
+ method="linear",
103
102
  fill_value=None,
104
103
  )
104
+ my_flat = interpolator(new_coordinates)
105
+
105
106
  my_img = my_img / my_flat
106
107
  my_img = correct_spikes(my_img, self.remove_spikes_threshold)
107
108
 
@@ -131,18 +131,18 @@ def generate_test_flatfield_generalized(
131
131
  flats = {}
132
132
  flats_urls = {}
133
133
  for i, flat_idx in enumerate(flats_indices):
134
- flats["flats_%04d" % flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i]
134
+ flats["flats_%06d" % flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i]
135
135
  flats_urls[flat_idx] = DataUrl(
136
- file_path=testffname, data_path=str("/flats/flats_%04d" % flat_idx), scheme="silx"
136
+ file_path=testffname, data_path=str("/flats/flats_%06d" % flat_idx), scheme="silx"
137
137
  )
138
138
 
139
139
  # Darks
140
140
  darks = {}
141
141
  darks_urls = {}
142
142
  for i, dark_idx in enumerate(darks_indices):
143
- darks["darks_%04d" % dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i]
143
+ darks["darks_%06d" % dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i]
144
144
  darks_urls[dark_idx] = DataUrl(
145
- file_path=testffname, data_path=str("/darks/darks_%04d" % dark_idx), scheme="silx"
145
+ file_path=testffname, data_path=str("/darks/darks_%06d" % dark_idx), scheme="silx"
146
146
  )
147
147
 
148
148
  dicttoh5(flats, testffname, h5path="/flats", mode="w")
@@ -455,7 +455,6 @@ class FlatFieldTestDataset:
455
455
  self.projs_data = np.zeros((len(self.projs_idx),) + self.shp, "f")
456
456
  self.projs = {}
457
457
  for i, proj_idx in enumerate(self.projs_idx):
458
-
459
458
  flat = self.get_flat(proj_idx)
460
459
 
461
460
  proj_val = self.dark_val + proj_idx * (flat[0, 0] - self.dark_val)
@@ -562,8 +561,8 @@ def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname):
562
561
  for i in range(n_radios):
563
562
  f_i = i + 2
564
563
  if (i % flat_interval) == 0:
565
- flats["flats_%04d" % i] = np.zeros(radio_shape, "f") + f_i
566
- flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%04d" % i), scheme="silx")
564
+ flats["flats_%06d" % i] = np.zeros(radio_shape, "f") + f_i
565
+ flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%06d" % i), scheme="silx")
567
566
  radios[i] = i * (f_i - 1) + 1
568
567
  dark = {"dark_0000": dark_data}
569
568
  dicttoh5(flats, testffname, h5path="/flats", mode="w")
@@ -2,6 +2,7 @@ import numpy as np
2
2
  from silx.opencl.backprojection import Backprojection
3
3
  from ..utils import deprecation_warning
4
4
 
5
+
5
6
  # Compatibility layer Nabu/silx
6
7
  class Backprojector:
7
8
  def __init__(
@@ -21,7 +22,6 @@ class Backprojector:
21
22
  profile=False,
22
23
  extra_options=None,
23
24
  ):
24
-
25
25
  if slice_roi and (
26
26
  slice_roi[0] > 0 or slice_roi[2] > 0 or slice_roi[1] < sino_shape[1] or slice_roi[3] < sino_shape[1]
27
27
  ):
@@ -8,7 +8,6 @@ from ..utils import get_cuda_srcfile, check_supported, updiv
8
8
 
9
9
 
10
10
  class SinoFilter:
11
-
12
11
  available_padding_modes = ["zeros", "edges"]
13
12
 
14
13
  def __init__(
@@ -85,6 +85,7 @@ class TestFBP:
85
85
  Test the "axis correction" feature
86
86
  """
87
87
  sino = self.sino_512
88
+
88
89
  # Create a sinogram with a drift in the rotation axis
89
90
  def create_drifted_sino(sino, drifts):
90
91
  out = np.zeros_like(sino)
@@ -19,7 +19,6 @@ _tomoscan_has_nxversion = parse_version(tomoscan_version) > parse_version("0.6.0
19
19
 
20
20
 
21
21
  class DatasetAnalyzer:
22
-
23
22
  _scanner = None
24
23
  kind = "none"
25
24