nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
nabu/resources/dataset_analyzer.py CHANGED
@@ -1,13 +1,14 @@
  import os
- from bisect import bisect_left
  import numpy as np
- from silx.io import get_data
  from silx.io.url import DataUrl
+ from silx.io import get_data
  from tomoscan.esrf.scan.edfscan import EDFTomoScan
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
- from ..utils import check_supported
+
+ from ..utils import check_supported, indices_to_slices
+ from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader
  from ..io.utils import get_compacted_dataslices
- from .utils import is_hdf5_extension, get_values_from_file
+ from .utils import get_values_from_file, is_hdf5_extension
  from .logger import LoggerOrPrint

  from ..pipeline.utils import nabu_env_settings
@@ -51,7 +52,7 @@ class DatasetAnalyzer:
  "output_dir": None,
  "exclude_projections": None,
  "hdf5_entry": None,
- "nx_version": 1.0,
+ # "nx_version": 1.0,
  }
  # --
  advanced_options.update(extra_options)
@@ -59,16 +60,18 @@ class DatasetAnalyzer:

  # pylint: disable=E1136
  def _get_excluded_projections(self):
- self._ignore_projections_indices = None
- self._need_rebuild_tomoscan_object_to_exclude_projections = False
  excluded_projs = self.extra_options["exclude_projections"]
+ self._ignore_projections = None
  if excluded_projs is None:
  return
- if excluded_projs["type"] == "indices":
- projs_idx = get_values_from_file(excluded_projs["file"], any_size=True).astype(np.int32).tolist()
- self._ignore_projections_indices = projs_idx
- else:
- self._need_rebuild_tomoscan_object_to_exclude_projections = True
+
+ if excluded_projs["type"] == "angular_range":
+ excluded_projs["type"] = "range" # compat with tomoscan #pylint: disable=E1137
+ values = excluded_projs["range"]
+ for ignore_kind, dtype in {"indices": np.int32, "angles": np.float32}.items():
+ if excluded_projs["type"] == ignore_kind:
+ values = get_values_from_file(excluded_projs["file"], any_size=True).astype(dtype).tolist()
+ self._ignore_projections = {"kind": excluded_projs["type"], "values": values} # pylint: disable=E0606

  def _init_dataset_scan(self, **kwargs):
  if self._scanner is None:
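Editor's note: the rewritten _get_excluded_projections collapses the three exclusion modes (indices file, angles file, angular range) into a single dict handed to tomoscan's ignore_projections argument. A minimal standalone sketch of that mapping, with numpy.loadtxt standing in for nabu's get_values_from_file (an assumption made only for self-containedness):

import numpy as np

def build_ignore_projections(excluded_projs):
    # excluded_projs mirrors the nabu config dict: {"type": ..., "file" or "range": ...}
    if excluded_projs is None:
        return None
    kind = excluded_projs["type"]
    if kind == "angular_range":
        kind = "range"  # tomoscan's naming for an angular range
        values = excluded_projs["range"]  # e.g. (260.0, 275.0), in degrees
    elif kind == "indices":
        values = np.loadtxt(excluded_projs["file"]).astype(np.int32).tolist()
    elif kind == "angles":
        values = np.loadtxt(excluded_projs["file"]).astype(np.float32).tolist()
    else:
        raise ValueError("Unknown exclusion type: %s" % kind)
    return {"kind": kind, "values": values}

# build_ignore_projections({"type": "angular_range", "range": (260.0, 275.0)})
# -> {"kind": "range", "values": (260.0, 275.0)}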
@@ -83,40 +86,11 @@ class DatasetAnalyzer:
  kwargs["n_frames"] = 1

  self.dataset_scanner = self._scanner( # pylint: disable=E1102
- self.location, ignore_projections=self._ignore_projections_indices, **kwargs
+ self.location, ignore_projections=self._ignore_projections, **kwargs
  )
- self.projections = self.dataset_scanner.projections
-
- # ---
- if self._need_rebuild_tomoscan_object_to_exclude_projections:
- # pylint: disable=E1136
- exclude_projs = self.extra_options["exclude_projections"]
- rot_angles_deg = np.rad2deg(self.rotation_angles)
- self._rotation_angles = None # prevent caching
- # tomoscan only supports ignore_projections=<list of integers>
- # However this is cumbersome to use, it's more convenient to use angular range or list of angles
- # But having angles instead of indices implies to already have information on current scan angular range
- ignore_projections_indices = []
- if exclude_projs["type"] == "angular_range":
- exclude_angle_min, exclude_angle_max = exclude_projs["range"]
- projections_indices = np.array(sorted(self.dataset_scanner.projections.keys()))
- for proj_idx, angle in zip(projections_indices, rot_angles_deg):
- if exclude_angle_min <= angle and angle <= exclude_angle_max:
- ignore_projections_indices.append(proj_idx)
- elif exclude_projs["type"] == "angles":
- excluded_angles = get_values_from_file(exclude_projs["file"], any_size=True).astype(np.float32).tolist()
- for excluded_angle in excluded_angles:
- proj_idx = bisect_left(rot_angles_deg, excluded_angle)
- if proj_idx < rot_angles_deg.size:
- ignore_projections_indices.append(proj_idx)
- # Rebuild the dataset_scanner instance
- self._ignore_projections_indices = ignore_projections_indices
- self.dataset_scanner = self._scanner( # pylint: disable=E1102
- self.location, ignore_projections=self._ignore_projections_indices, **kwargs
- )
- # ---
- if self._ignore_projections_indices is not None:
- self.logger.info("Excluding projections: %s" % str(self._ignore_projections_indices))
+
+ if self._ignore_projections is not None:
+ self.logger.info("Excluding projections: %s" % str(self._ignore_projections))

  if nabu_env_settings.skip_tomoscan_checks:
  self.logger.warning(
@@ -124,9 +98,8 @@ class DatasetAnalyzer:
  )
  self.dataset_scanner.set_check_behavior(run_check=False, raise_error=False)

- self.projections = self.dataset_scanner.projections
- self.flats = self.dataset_scanner.flats
- self.darks = self.dataset_scanner.darks
+ self.raw_flats = self.dataset_scanner.flats
+ self.raw_darks = self.dataset_scanner.darks
  self.n_angles = len(self.dataset_scanner.projections)
  self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2)
  self._radio_dims_notbinned = self.radio_dims # COMPAT
@@ -146,7 +119,10 @@ class DatasetAnalyzer:
  self._pixel_size = None
  self._distance = None
  self._flats_srcurrent = None
+ self._projections = None
  self._projections_srcurrent = None
+ self._reduced_flats = None
+ self._reduced_darks = None

  @property
  def energy(self):
@@ -223,15 +199,19 @@
  def detector_tilt(self, tilt):
  self._detector_tilt = tilt

- def _get_srcurrent(self, indices):
- srcurrent = self.dataset_scanner.electric_current
- if srcurrent is None or len(srcurrent) == 0:
- return None
- srcurrent_all = np.array(srcurrent)
- if np.any(indices >= len(srcurrent_all)):
- self.logger.error("Something wrong with SRCurrent: not enough values!")
- return None
- return srcurrent_all[indices].astype("f")
+ def _get_srcurrent(self, frame_type):
+ # To be implemented by inheriting class
+ return None
+
+ @property
+ def projections(self):
+ if self._projections is None:
+ self._projections = self.dataset_scanner.projections
+ return self._projections
+
+ @projections.setter
+ def projections(self, val):
+ raise ValueError

  @property
  def projections_srcurrent(self):
@@ -239,8 +219,7 @@
  Return the synchrotron electric current for each projection.
  """
  if self._projections_srcurrent is None:
- projections_indices = np.array(sorted(self.projections.keys()))
- self._projections_srcurrent = self._get_srcurrent(projections_indices)
+ self._projections_srcurrent = self._get_srcurrent("radios") # pylint: disable=E1128
  return self._projections_srcurrent

  @projections_srcurrent.setter
@@ -253,8 +232,7 @@
  Return the synchrotron electric current for each flat image.
  """
  if self._flats_srcurrent is None:
- flats_indices = np.array(sorted(self.flats.keys()))
- self._flats_srcurrent = self._get_srcurrent(flats_indices)
+ self._flats_srcurrent = self._get_srcurrent("flats") # pylint: disable=E1128
  return self._flats_srcurrent

  @flats_srcurrent.setter
@@ -268,6 +246,32 @@
  if getattr(self, name, None) is None:
  raise ValueError(error_msg or str("No information on %s was found in the dataset" % name))

+ @property
+ def flats(self):
+ """
+ Return the REDUCED flat-field images. Either by reducing (median) the raw flats, or a user-defined reduced flats.
+ """
+ if self._reduced_flats is None:
+ self._reduced_flats = self.get_reduced_flats()
+ return self._reduced_flats
+
+ @flats.setter
+ def flats(self, val):
+ self._reduced_flats = val
+
+ @property
+ def darks(self):
+ """
+ Return the REDUCED flat-field images. Either by reducing (mean) the raw darks, or a user-defined reduced darks.
+ """
+ if self._reduced_darks is None:
+ self._reduced_darks = self.get_reduced_darks()
+ return self._reduced_darks
+
+ @darks.setter
+ def darks(self, val):
+ self._reduced_darks = val
+

  class EDFDatasetAnalyzer(DatasetAnalyzer):
  """
@@ -278,23 +282,7 @@ class EDFDatasetAnalyzer(DatasetAnalyzer):
  kind = "edf"

  def _finish_init(self):
- self.remove_unused_radios()
-
- def remove_unused_radios(self):
- """
- Remove "unused" radios.
- This is used for legacy ESRF scans.
- """
- # Extraneous projections are assumed to be on the end
- projs_indices = sorted(self.projections.keys())
- used_radios_range = range(projs_indices[0], len(self.projections))
- radios_not_used = []
- for idx in self.projections.keys():
- if idx not in used_radios_range:
- radios_not_used.append(idx)
- for idx in radios_not_used:
- self.projections.pop(idx)
- return radios_not_used
+ pass

  def _get_flats_darks(self):
  return
@@ -311,13 +299,34 @@ class EDFDatasetAnalyzer(DatasetAnalyzer):
  return None

  def _get_rotation_angles(self):
- if self._rotation_angles is None:
- scan_range = self.dataset_scanner.scan_range
- if scan_range is not None:
- fullturn = abs(scan_range - 360) < abs(scan_range - 180)
- angles = np.linspace(0, scan_range, num=len(self.projections), endpoint=fullturn, dtype="f")
- self._rotation_angles = np.deg2rad(angles)
- return self._rotation_angles
+ return np.deg2rad(self.dataset_scanner.rotation_angle())
+
+ def get_reduced_flats(self, **reader_kwargs):
+ if self.raw_flats in [None, {}]:
+ raise FileNotFoundError("No reduced flat ('refHST') found in %s" % self.location)
+ # A few notes:
+ # (1) In principle we could do the reduction (mean/median) from raw frames (ref_xxxx_yyyy)
+ # but for legacy datasets it's always already done (by fasttomo3), and EDF support is supposed to be dropped on our side
+ # (2) We use EDFStackReader class to handle the possible additional data modifications
+ # (eg. subsampling, binning, distortion correction...)
+ # (3) The following spawns one reader instance per file, which is not elegant,
+ # but in principle there are typically 1-2 reduced flats in a scan
+ readers = {k: EDFStackReader([self.raw_flats[k].file_path()], **reader_kwargs) for k in self.raw_flats.keys()}
+ return {k: readers[k].load_data()[0] for k in self.raw_flats.keys()}
+
+ def get_reduced_darks(self, **reader_kwargs):
+ # See notes in get_reduced_flats() above
+ if self.raw_darks in [None, {}]:
+ raise FileNotFoundError("No reduced dark ('darkend.edf' or 'dark.edf') found in %s" % self.location)
+ readers = {k: EDFStackReader([self.raw_darks[k].file_path()], **reader_kwargs) for k in self.raw_darks.keys()}
+ return {k: readers[k].load_data()[0] for k in self.raw_darks.keys()}
+
+ @property
+ def files(self):
+ return sorted([u.file_path() for u in self.dataset_scanner.projections.values()])
+
+ def get_reader(self, **kwargs):
+ return EDFStackReader(self.files, **kwargs)


  class HDF5DatasetAnalyzer(DatasetAnalyzer):
@@ -326,7 +335,10 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
  """

  _scanner = NXtomoScan
- kind = "hdf5"
+ kind = "nx"
+ # We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
+ _image_key_value = {"flats": 1, "darks": 2, "radios": 0}
+ #

  @property
  def z_translation(self):
@@ -353,10 +365,10 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
  def _get_dataset_hdf5_url(self):
  if len(self.projections) > 0:
  frames_to_take = self.projections
- elif len(self.flats) > 0:
- frames_to_take = self.flats
- elif len(self.darks) > 0:
- frames_to_take = self.darks
+ elif len(self.raw_flats) > 0:
+ frames_to_take = self.raw_flats
+ elif len(self.raw_darks) > 0:
+ frames_to_take = self.raw_darks
  else:
  raise ValueError("No projections, no flats and no darks ?!")
  first_proj_idx = sorted(frames_to_take.keys())[0]
@@ -397,8 +409,13 @@
  slices: list of slice
  A list where each item is a slice.
  """
- check_supported(what, ["projections", "flats", "darks"], "image type")
- images = getattr(self, what) # dict
+ name_to_attr = {
+ "projections": self.projections,
+ "flats": self.raw_flats,
+ "darks": self.raw_darks,
+ }
+ check_supported(what, name_to_attr.keys(), "image type")
+ images = name_to_attr[what] # dict
  # we can't directly use set() on slice() object (unhashable). Use tuples
  slices = set()
  for du in get_compacted_dataslices(images).values():
@@ -410,6 +427,39 @@
  slices_list = [slice(item[0], item[1]) if item is not None else None for item in list(slices)]
  return slices_list

+ def _select_according_to_frame_type(self, data, frame_type):
+ if data is None:
+ return None
+ return data[self.dataset_scanner.image_key_control == self._image_key_value[frame_type]]
+
+ def get_reduced_flats(self, method="median", force_reload=False, **reader_kwargs):
+ dkrf_reader = NXDarksFlats(
+ self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs
+ )
+ return dkrf_reader.get_reduced_flats(method=method, force_reload=force_reload, as_dict=True)
+
+ def get_reduced_darks(self, method="mean", force_reload=False, **reader_kwargs):
+ dkrf_reader = NXDarksFlats(
+ self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **reader_kwargs
+ )
+ return dkrf_reader.get_reduced_darks(method=method, force_reload=force_reload, as_dict=True)
+
+ def _get_srcurrent(self, frame_type):
+ return self._select_according_to_frame_type(self.dataset_scanner.electric_current, frame_type)
+
+ def frames_slices(self, frame_type):
+ """
+ Return a list of slice objects corresponding to the data corresponding to "frame_type".
+ For example, if the dataset flats are located at indices [1, 2, ..., 99], then
+ frame_slices("flats") will return [slice(0, 100)].
+ """
+ return indices_to_slices(
+ np.where(self.dataset_scanner.image_key_control == self._image_key_value[frame_type])[0]
+ )
+
+ def get_reader(self, **kwargs):
+ return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs)
+

  def analyze_dataset(dataset_path, extra_options=None, logger=None):
  if not (os.path.isdir(dataset_path)):
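Editor's note: the _image_key_value shortcut relies on the NXtomo convention visible in the diff itself (0 = projection/radio, 1 = flat, 2 = dark in image_key_control). A self-contained NumPy sketch of the boolean-mask selection performed by _select_according_to_frame_type and _get_srcurrent:

import numpy as np

image_key_control = np.array([2, 2, 1, 1, 0, 0, 0, 1, 1])  # toy scan: darks, flats, radios, flats
_image_key_value = {"flats": 1, "darks": 2, "radios": 0}

# Fake per-frame SR current: one value per frame, like dataset_scanner.electric_current
electric_current = np.linspace(200.0, 190.0, image_key_control.size)

# Select the metadata of the flat-field frames only
flats_current = electric_current[image_key_control == _image_key_value["flats"]]
print(flats_current)  # 4 values, at frame indices 2, 3, 7, 8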
nabu/resources/gpu.py CHANGED
@@ -1,6 +1,7 @@
  """
  gpu.py: general-purpose utilities for GPU
  """
+
  from ..utils import check_supported

  try:
nabu/resources/nxflatfield.py CHANGED
@@ -1,14 +1,16 @@
  import os
  import numpy as np
+ from nxtomo.io import HDF5File
  from silx.io.url import DataUrl
- from tomoscan.io import HDF5File
+ from silx.io import get_data
+ from tomoscan.framereducer.reducedframesinfos import ReducedFramesInfos
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
  from ..utils import check_supported, is_writeable


- def get_frame_possible_urls(dataset_info, user_dir, output_dir, frame_type):
+ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
  """
- Return a list with the possible location of reduced dark/flat frames.
+ Return a dict with the possible location of reduced dark/flat frames.

  Parameters
  ----------
@@ -18,19 +20,21 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir, frame_type):
  User-provided directory location for the reduced frames.
  output_dir: str or None
  Output processing directory
- frame_type: str
- Frame type, can be "flats" or "darks".
  """
- check_supported(frame_type, ["flats", "darks"], "frame type")

+ frame_types = ["flats", "darks"]
  h5scan = dataset_info.dataset_scanner # tomoscan object
- if frame_type == "flats":
- dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0]
- else:
- dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0]

- def make_dataurl(dirname):
- # The template formatting should be done by tomoscan in principle, but this complicates logging.
+ def make_dataurl(dirname, frame_type):
+ """
+ The template formatting should be done by tomoscan in principle, but this complicates logging.
+ """
+
+ if frame_type == "flats":
+ dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0]
+ else:
+ dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0]
+
  rel_file_path = dataurl_default_template.file_path().format(
  scan_prefix=dataset_info.dataset_scanner.get_dataset_basename()
  )
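Editor's note: since make_dataurl now takes the frame type, each entry of the returned dict holds one URL per frame type. An illustrative shape of the new return value (plain strings here; the real code returns silx DataUrl objects, and the file names are hypothetical):

urls = {
    "user":    {"flats": "user_dir/scan_flats.hdf5", "darks": "user_dir/scan_darks.hdf5"},
    "dataset": {"flats": "scan_dir/scan_flats.hdf5", "darks": "scan_dir/scan_darks.hdf5"},
    "output":  None,  # stays None unless an output directory is configured
}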
@@ -44,18 +48,73 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir, frame_type):
  urls = {"user": None, "dataset": None, "output": None}

  if user_dir is not None:
- urls["user"] = make_dataurl(user_dir)
+ urls["user"] = {frame_type: make_dataurl(user_dir, frame_type) for frame_type in frame_types}

  # tomoscan.esrf.scan.hdf5scan.REDUCED_{DARKS|FLATS}_DATAURLS.file_path() is a relative path
  # Create a absolute path instead
- urls["dataset"] = make_dataurl(os.path.dirname(h5scan.master_file))
+ urls["dataset"] = {
+ frame_type: make_dataurl(os.path.dirname(h5scan.master_file), frame_type) for frame_type in frame_types
+ }

  if output_dir is not None:
- urls["output"] = make_dataurl(output_dir)
+ urls["output"] = {frame_type: make_dataurl(output_dir, frame_type) for frame_type in frame_types}

  return urls


+ def save_reduced_frames(dataset_info, reduced_frames_arrays, reduced_frames_urls):
+ reduce_func = {"flats": np.median, "darks": np.mean} # TODO configurable ?
+
+ # Get "where to write". tomoscan expects a DataUrl
+ darks_flats_dir_url = reduced_frames_urls.get("user", None)
+ if darks_flats_dir_url is not None:
+ output_url = darks_flats_dir_url
+ elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())):
+ output_url = reduced_frames_urls["dataset"]
+ else:
+ output_url = reduced_frames_urls["output"]
+
+ # Get the "ReducedFrameInfos" data structure expected by tomoscan
+ def _get_additional_info(frame_type):
+ electric_current = dataset_info.dataset_scanner.electric_current
+ count_time = dataset_info.dataset_scanner.count_time
+ if electric_current is not None:
+ electric_current = {
+ sl.start: reduce_func[frame_type](electric_current[sl]) for sl in dataset_info.frames_slices(frame_type)
+ }
+ electric_current = [electric_current[k] for k in sorted(electric_current.keys())]
+ if count_time is not None:
+ count_time = {
+ sl.start: reduce_func[frame_type](count_time[sl]) for sl in dataset_info.frames_slices(frame_type)
+ }
+ count_time = [count_time[k] for k in sorted(count_time.keys())]
+ info = ReducedFramesInfos()
+ info.count_time = count_time
+ info.machine_electric_current = electric_current
+ return info
+
+ flats_info = _get_additional_info("flats")
+ darks_info = _get_additional_info("darks")
+
+ # Call tomoscan to save the reduced frames
+ dataset_info.dataset_scanner.save_reduced_darks(
+ reduced_frames_arrays["darks"],
+ output_urls=[output_url["darks"]],
+ darks_infos=darks_info,
+ metadata_output_urls=[get_metadata_url(output_url["darks"], "darks")],
+ overwrite=True,
+ )
+ dataset_info.dataset_scanner.save_reduced_flats(
+ reduced_frames_arrays["flats"],
+ output_urls=[output_url["flats"]],
+ flats_infos=flats_info,
+ metadata_output_urls=[get_metadata_url(output_url["flats"], "flats")],
+ overwrite=True,
+ )
+ dataset_info.logger.info("Saved reduced darks/flats to %s" % output_url["flats"].file_path())
+ return output_url, flats_info, darks_info
+
+
  def get_metadata_url(url, frame_type):
  """
  Return the url of the metadata stored alongside flats/darks
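Editor's note: in _get_additional_info, per-frame metadata (SR current, count time) is reduced to one scalar per darks/flats series: each slice returned by frames_slices is collapsed with the same reduction function as the frames themselves. A toy sketch of that step:

import numpy as np

electric_current = np.array([200.0, 199.0, 198.0, 150.0, 149.0])  # fake per-frame values
series_slices = [slice(0, 3), slice(3, 5)]  # two series of flats, as frames_slices() would return
reduced = {sl.start: np.median(electric_current[sl]) for sl in series_slices}  # median, as for flats
print([reduced[k] for k in sorted(reduced)])  # [199.0, 149.5]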
@@ -79,12 +138,15 @@ def tomoscan_load_reduced_frames(dataset_info, frame_type, url):
  )


- def tomoscan_save_reduced_frames(dataset_info, frame_type, url, frames, info):
- tomoscan_method = getattr(dataset_info.dataset_scanner, "save_reduced_%s" % frame_type)
- kwargs = {"%s_infos" % frame_type: info, "overwrite": True}
- return tomoscan_method(
- frames, output_urls=[url], metadata_output_urls=[get_metadata_url(url, frame_type)], **kwargs
- )
+ def data_url_exists(data_url):
+ """
+ Return true iff the file exists and the data URL is valid (i.e data/group is actually in the file)
+ """
+ if not (os.path.isfile(data_url.file_path())):
+ return False
+ with HDF5File(data_url.file_path(), "r") as f:
+ path_exists = f.get(data_url.data_path(), default=None) is not None
+ return path_exists


  # pylint: disable=E1136
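Editor's note: data_url_exists replaces the old load-and-fail approach: reduced frames are only loaded from a location where both the file and the HDF5 path actually exist. A standalone equivalent using h5py directly (an assumption: nxtomo.io.HDF5File is used as a drop-in h5py.File context manager, which the with-statement above suggests):

import os
import h5py

def data_url_exists(file_path, data_path):
    # True iff the file exists AND the group/dataset is present inside it
    if not os.path.isfile(file_path):
        return False
    with h5py.File(file_path, "r") as f:
        return f.get(data_path, default=None) is not None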
@@ -107,111 +169,58 @@ def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=Non
  """
  if flatfield_mode is False:
  return
- logger = dataset_info.logger
+
  frames_types = ["darks", "flats"]
+ reduced_frames_urls = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir)
+
+ def _compute_and_save_reduced_frames():
+ try:
+ dataset_info.flats = dataset_info.get_reduced_flats()
+ dataset_info.darks = dataset_info.get_reduced_darks()
+ except FileNotFoundError:
+ msg = "Could not find any flats and/or darks"
+ raise FileNotFoundError(msg)
+ _, flats_info, darks_info = save_reduced_frames(
+ dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
+ )
+ dataset_info.flats_srcurrent = flats_info.machine_electric_current

- reduced_frames_urls = {}
- for frame_type in frames_types:
- reduced_frames_urls[frame_type] = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir, frame_type)
-
- reduced_frames = dict.fromkeys(frames_types, None)
-
- #
- # Try to load frames
- #
- def load_reduced_frame(url, frame_type, frames_loaded, reduced_frames):
- if frames_loaded[frame_type]:
- return
- frames, info = tomoscan_load_reduced_frames(dataset_info, frame_type, url)
- if frames not in (None, {}):
- dataset_info.logger.info("Loaded %s from %s" % (frame_type, url.file_path()))
- frames_loaded[frame_type] = True
- reduced_frames[frame_type] = frames, info
- else:
- msg = "Could not load %s from %s" % (frame_type, url.file_path())
- logger.error(msg)
-
- frames_loaded = dict.fromkeys(frames_types, False)
- if flatfield_mode != "force-compute":
- for load_from in ["user", "dataset", "output"]: # in that order
- for frame_type in frames_types:
- url = reduced_frames_urls[frame_type][load_from]
- if url is None:
- continue # cannot load from this source (eg. undefined folder)
- load_reduced_frame(url, frame_type, frames_loaded, reduced_frames)
- if all(frames_loaded.values()):
- break
-
- if not all(frames_loaded.values()) and flatfield_mode == "force-load":
- raise ValueError("Could not load darks/flats (using 'force-load')")
+ if flatfield_mode == "force-compute":
+ _compute_and_save_reduced_frames()
+ return

- #
- # COMPAT. Keep DataUrl - won't be needed in future versions when pipeline will use FlatField
- # instead of FlatFieldDataUrl
- frames_urls = reduced_frames.copy()
- #
-
- # Compute reduced frames, if needed
- #
- if reduced_frames["flats"] is None:
- reduced_frames["flats"] = dataset_info.dataset_scanner.compute_reduced_flats(return_info=True)
- if reduced_frames["darks"] is None:
- reduced_frames["darks"] = dataset_info.dataset_scanner.compute_reduced_darks(return_info=True)
-
- if reduced_frames["darks"][0] == {} or reduced_frames["flats"][0] == {}:
- raise ValueError(
- "Could not get any reduced flat/dark. This probably means that no already-reduced flats/darks were found and that the dataset itself does not have any flat/dark"
- )
+ def _can_load_from(folder_type):
+ if reduced_frames_urls.get(folder_type, None) is None:
+ return False
+ return all([data_url_exists(reduced_frames_urls[folder_type][frame_type]) for frame_type in frames_types])
+
+ where_to_load_from = None
+ if reduced_frames_urls["user"] is not None and _can_load_from("user"):
+ where_to_load_from = "user"
+ elif _can_load_from("dataset"):
+ where_to_load_from = "dataset"
+ elif _can_load_from("output"):
+ where_to_load_from = "output"

- #
- # Save reduced frames
- #
-
- def save_reduced_frame(url, frame_type, frames_saved):
- frames, info = reduced_frames[frame_type]
- tomoscan_save_reduced_frames(dataset_info, frame_type, url, frames, info)
- dataset_info.logger.info("Saved reduced %s to %s" % (frame_type, url.file_path()))
- frames_saved[frame_type] = True
-
- frames_saved = dict.fromkeys(frames_types, False)
- if not all(frames_loaded.values()):
- for save_to in ["user", "dataset", "output"]: # in that order
- for frame_type in frames_types:
- if frames_loaded[frame_type]:
- continue # already loaded
- url = reduced_frames_urls[frame_type][save_to]
- if url is None:
- continue # cannot load from this source (eg. undefined folder)
- if not is_writeable(os.path.dirname(url.file_path())):
- continue
- save_reduced_frame(url, frame_type, frames_saved)
- # COMPAT.
- if frames_urls[frame_type] is None:
- frames_urls[frame_type] = tomoscan_load_reduced_frames(dataset_info, frame_type, url)
- #
- if all(frames_saved.values()):
- break
-
- dataset_info.flats = frames_urls["flats"][0] # reduced_frames["flats"] # in future versions
- dataset_info.flats_srcurrent = frames_urls["flats"][1].machine_electric_current
- # This is an extra check to avoid having more than 1 (reduced) dark.
- # FlatField only works with exactly 1 (reduced) dark (having more than 1 series of darks makes little sense)
- # This is normally prevented by tomoscan HDF5FramesReducer, but let's add this extra check
- darks_ = frames_urls["darks"][0] # reduced_frames["darks"] # in future versions
- if len(darks_) > 1:
- dark_idx = sorted(darks_.keys())[0]
- dataset_info.logger.error("Found more that one series of darks. Keeping only the first one")
- darks_ = {dark_idx: darks_[dark_idx]}
- #
- dataset_info.darks = darks_
- dataset_info.darks_srcurrent = frames_urls["darks"][1].machine_electric_current
-
-
- # tomoscan "compute_reduced_XX" is quite slow. If needed, here is an alternative implementation
- def my_reduce_flats(di):
- res = {}
- with HDF5File(di.dataset_hdf5_url.file_path(), "r") as f:
- for data_slice in di.get_data_slices("flats"):
- data = f[di.dataset_hdf5_url.data_path()][data_slice.start : data_slice.stop]
- res[data_slice.start] = np.median(data, axis=0)
- return res
+ if where_to_load_from == None and flatfield_mode == "force-load":
+ raise ValueError("Could not load darks/flats (using 'force-load')")
+
+ if where_to_load_from is not None:
+ reduced_frames_with_info = {}
+ for frame_type in frames_types:
+ reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
+ dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
+ )
+ dataset_info.logger.info(
+ "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
+ )
+ red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
+ setattr(
+ dataset_info,
+ frame_type,
+ {k: get_data(red_frames_dict[k]) for k in red_frames_dict.keys()},
+ )
+ if frame_type == "flats":
+ dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
+ else:
+ _compute_and_save_reduced_frames()
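Editor's note: the refactored update_dataset_info_flats_darks boils down to a small decision table: force-compute always recomputes, otherwise the first loadable source among "user", "dataset", "output" wins, force-load errors out if none is loadable, and anything else falls back to compute-and-save. A condensed, runnable sketch of just that control flow (function and argument names here are illustrative, not nabu API):

def resolve_flats_darks(flatfield_mode, can_load_from):
    # can_load_from: e.g. {"user": False, "dataset": True, "output": False}
    if flatfield_mode == "force-compute":
        return "compute"
    for source in ("user", "dataset", "output"):  # priority order, as in the diff
        if can_load_from.get(source):
            return "load:" + source
    if flatfield_mode == "force-load":
        raise ValueError("Could not load darks/flats (using 'force-load')")
    return "compute"

# resolve_flats_darks("force-load", {"dataset": True}) -> "load:dataset"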