nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/stitching/config.py CHANGED
@@ -28,36 +28,27 @@ __license__ = "MIT"
28
28
  __date__ = "10/05/2022"
29
29
 
30
30
 
31
+ from math import ceil
32
+ from typing import Any, Iterable, Optional, Union, Sized
31
33
  from dataclasses import dataclass
32
34
  import numpy
35
+ from pyunitsystem.metricsystem import MetricSystem
36
+ from nxtomo.paths import nxtomo
37
+ from tomoscan.factory import Factory
33
38
  from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
34
- from tomoscan.esrf import HDF5TomoScan
35
- from tomoscan.nexus.paths import nxtomo
36
- from silx.utils.enum import Enum as _Enum
37
- from typing import Optional, Union, Sized
38
- from nabu.pipeline.config_validators import (
39
- integer_validator,
40
- list_of_shift_validator,
41
- list_of_tomoscan_identifier,
42
- optional_directory_location_validator,
39
+ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
40
+ from ..pipeline.config_validators import (
43
41
  boolean_validator,
44
42
  convert_to_bool,
45
- optional_positive_integer_validator,
46
- output_file_format_validator,
47
- optional_tuple_of_floats_validator,
48
- optional_file_name_validator,
49
43
  )
50
- from nabu.stitching.overlap import OverlapStitchingStrategy
51
- from nabu.utils import concatenate_dict, convert_str_to_tuple
52
- from nabu.io.utils import get_output_volume
53
- from tomoscan.factory import Factory
54
- from typing import Iterable
55
- from nabu.stitching.utils import ShiftAlgorithm
56
- from tomoscan.unitsystem.metricsystem import MetricSystem
57
- from math import ceil
58
-
44
+ from ..utils import concatenate_dict, convert_str_to_tuple
45
+ from ..io.utils import get_output_volume
46
+ from .overlap import OverlapStitchingStrategy
47
+ from .utils import ShiftAlgorithm
48
+ from .definitions import StitchingType
49
+ from .alignment import AlignmentAxis1, AlignmentAxis2
50
+ from pyunitsystem.metricsystem import MetricSystem
59
51
 
60
- KEY_SCORE_METHOD = "score_method"
61
52
 
62
53
  KEY_IMG_REG_METHOD = "img_reg_method"
63
54
 
@@ -137,6 +128,12 @@ KEY_RESCALE_MIN_PERCENTILES = "rescale_min_percentile"
137
128
 
138
129
  KEY_RESCALE_MAX_PERCENTILES = "rescale_max_percentile"
139
130
 
131
+ ALIGNMENT_AXIS_2_FIELD = "alignment_axis_2"
132
+
133
+ ALIGNMENT_AXIS_1_FIELD = "alignment_axis_1"
134
+
135
+ PAD_MODE_FIELD = "pad_mode"
136
+
140
137
  # SLURM
141
138
 
142
139
  SLURM_SECTION = "slurm"
@@ -155,8 +152,24 @@ SLURM_OTHER_OPTIONS = "other_options"
155
152
 
156
153
  SLURM_PREPROCESSING_COMMAND = "python_venv"
157
154
 
155
+ SLURM_MODULES_TO_LOADS = "modules"
156
+
158
157
  SLURM_CLEAN_SCRIPTS = "clean_scripts"
159
158
 
159
+ # normalization by sample
160
+
161
+ NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample"
162
+
163
+ NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD = "active"
164
+
165
+ NORMALIZATION_BY_SAMPLE_METHOD = "method"
166
+
167
+ NORMALIZATION_BY_SAMPLE_SIDE = "side"
168
+
169
+ NORMALIZATION_BY_SAMPLE_MARGIN = "margin"
170
+
171
+ NORMALIZATION_BY_SAMPLE_WIDTH = "width"
172
+
160
173
  # kernel extra options
161
174
 
162
175
  STITCHING_KERNELS_EXTRA_PARAMS = "stitching_kernels_extra_params"
@@ -168,7 +181,6 @@ CROSS_CORRELATION_METHODS_AXIS_0 = {
168
181
  "": "", # for display
169
182
  ShiftAlgorithm.NABU_FFT.value: "will call nabu `find_shift_correlate` function - shift search in fourier space",
170
183
  ShiftAlgorithm.SKIMAGE.value: "use scikit image `phase_cross_correlation` function in real space",
171
- ShiftAlgorithm.SHIFT_GRID.value: "will compute a score for each possible shift and pick the shift with the highest score",
172
184
  ShiftAlgorithm.NONE.value: "no shift research is done. will only get shift from motor positions",
173
185
  }
174
186
 
@@ -189,6 +201,7 @@ SECTIONS_COMMENTS = {
189
201
  OUTPUT_SECTION: "section dedicated to output parameters\n",
190
202
  INPUTS_SECTION: "section dedicated to inputs\n",
191
203
  SLURM_SECTION: "section didicated to slurm. If you want to run locally avoid setting 'partition or remove this section'",
204
+ NORMALIZATION_BY_SAMPLE_SECTION: "section dedicated to normalization by a sample. If activate each frame can be normalized by a sample of the frame",
192
205
  }
193
206
 
194
207
  DEFAULT_SHIFT_ALG_AXIS_0 = "nabu-fft"
@@ -202,9 +215,7 @@ _shift_algs_axis_2 = "\n + ".join(
202
215
  )
203
216
 
204
217
  HELP_SHIFT_PARAMS = f"""options for shifts algorithms as `key1=value1,key2=value2`. For now valid keys are:
205
- - {KEY_WINDOW_SIZE}: size of the window for the 'shift-grid' algorithm'.
206
218
  - {KEY_OVERLAP_SIZE}: size to apply stitching. If not provided will take the largest size possible'.
207
- - {KEY_SCORE_METHOD}: method to use in order to compute score for the 'shift-grid' algorithm. Values can be 'tv' (total variation), '1/tv', 'std' (standard deviation), '1/std'.
208
219
  - {KEY_IMG_REG_METHOD}: algorithm to use to find overlaps between the different sections. Possible values are \n * for axis 0: {_shift_algs_axis_0}\n * and for axis 2: {_shift_algs_axis_2}
209
220
  - {KEY_LOW_PASS_FILTER}: low pass filter value for filtering frames before shift research
210
221
  - {KEY_HIGH_PASS_FILTER}: high pass filter value for filtering frames before shift research"""
@@ -263,7 +274,6 @@ def _valid_shifts_params(my_dict: Union[dict, str]):
263
274
  my_dict = _str_to_dict(my_str=my_dict)
264
275
 
265
276
  valid_keys = (
266
- KEY_SCORE_METHOD,
267
277
  KEY_WINDOW_SIZE,
268
278
  KEY_IMG_REG_METHOD,
269
279
  KEY_OVERLAP_SIZE,
@@ -326,6 +336,106 @@ def _scalar_or_tuple_to_bool_or_tuple_of_bool(my_str: Union[bool, tuple, str], d
326
336
  return values
327
337
 
328
338
 
339
+ from nabu.stitching.sample_normalization import Method, SampleSide
340
+
341
+
342
class NormalizationBySample:
    """
    Settings of the 'normalization by sample' section of a stitching configuration.

    When active, each frame can be normalized by a sample of the frame. The sample is
    described by:

    * ``method``: method used to compute the normalization value
    * ``side``: side of the frame on which the sample is picked
    * ``margin``: margin (in px) between the border and the sample
    * ``width``: sample width (in px)
    """

    def __init__(self) -> None:
        self._active = False
        # NOTE(review): the configuration description dict advertises "median" as the
        # default method while the default used here is MEAN - confirm which one is intended.
        self._method = Method.MEAN
        self._margin = 0
        self._side = SampleSide.LEFT
        self._width = 30

    def is_active(self):
        """Return True if the normalization by sample is requested."""
        return self._active

    def set_is_active(self, active: bool):
        """
        Activate or deactivate the normalization by sample.

        :param active: new activation state. Must be a bool.
        """
        assert isinstance(
            active, bool
        ), f"active is expected to be a bool. Get {type(active)} instead. Value == {active}"
        self._active = active

    @property
    def method(self) -> Method:
        """Method used to compute the normalization value."""
        return self._method

    @method.setter
    def method(self, method: Union[Method, str]) -> None:
        # accept either the enum member or its value; conversion is delegated to `from_value`
        self._method = Method.from_value(method)

    @property
    def margin(self) -> int:
        """Margin (in px) between the border and the sample."""
        return self._margin

    @margin.setter
    def margin(self, margin: int):
        assert isinstance(margin, int), f"margin is expected to be an int. Get {type(margin)} instead"
        self._margin = margin

    @property
    def side(self) -> SampleSide:
        """Side of the frame on which the sample is picked."""
        return self._side

    @side.setter
    def side(self, side: Union[SampleSide, str]):
        # accept either the enum member or its value; conversion is delegated to `from_value`
        self._side = SampleSide.from_value(side)

    @property
    def width(self) -> int:
        """Sample width (in px)."""
        return self._width

    @width.setter
    def width(self, width: int):
        assert isinstance(width, int), f"width is expected to be an int. Get {type(width)} instead"
        # store the validated value: without this assignment the setter was a no-op and
        # the width always stayed at the constructor default
        self._width = width

    @staticmethod
    def from_dict(my_dict: dict):
        """
        Build a NormalizationBySample from a configuration dictionary.

        Keys missing from `my_dict` (or set to None) keep the constructor default.

        :param my_dict: content of the 'normalization by sample' section
        :return: the configured NormalizationBySample instance
        """
        sample_normalization = NormalizationBySample()
        # active
        active = my_dict.get(NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD, None)
        if active is not None:
            # values can come from a text configuration: accept bool, int and str flavours of 'true'
            active = active in (True, "True", 1, "1")
            sample_normalization.set_is_active(active)

        # method
        method = my_dict.get(NORMALIZATION_BY_SAMPLE_METHOD, None)
        if method is not None:
            sample_normalization.method = method

        # margin
        margin = my_dict.get(NORMALIZATION_BY_SAMPLE_MARGIN, None)
        if margin is not None:
            sample_normalization.margin = int(margin)

        # side
        side = my_dict.get(NORMALIZATION_BY_SAMPLE_SIDE, None)
        if side is not None:
            sample_normalization.side = side

        # width
        width = my_dict.get(NORMALIZATION_BY_SAMPLE_WIDTH, None)
        if width is not None:
            sample_normalization.width = int(width)

        return sample_normalization

    def to_dict(self) -> dict:
        """Dump the settings to a dictionary (enum members are dumped as their value)."""
        return {
            NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: self.is_active(),
            NORMALIZATION_BY_SAMPLE_METHOD: self.method.value,
            NORMALIZATION_BY_SAMPLE_MARGIN: self.margin,
            NORMALIZATION_BY_SAMPLE_SIDE: self.side.value,
            NORMALIZATION_BY_SAMPLE_WIDTH: self.width,
        }

    def __eq__(self, __value: object) -> bool:
        # two instances are equal when all their dumped settings are equal
        if not isinstance(__value, NormalizationBySample):
            return False
        else:
            return self.to_dict() == __value.to_dict()
437
+
438
+
329
439
  @dataclass
330
440
  class SlurmConfig:
331
441
  "configuration for slurm jobs"
@@ -334,10 +444,18 @@ class SlurmConfig:
334
444
  n_jobs: int = 1
335
445
  other_options: str = ""
336
446
  preprocessing_command: str = ""
447
+ modules_to_load: tuple = tuple()
337
448
  clean_script: bool = ""
338
449
  n_tasks: int = 1
339
450
  n_cpu_per_task: int = 4
340
451
 
452
+ def __post_init__(self) -> None:
453
+ # make sure either 'modules' or 'preprocessing_command' is provided
454
+ if len(self.modules_to_load) > 0 and self.preprocessing_command not in (None, ""):
455
+ raise ValueError(
456
+ f"Either modules {SLURM_MODULES_TO_LOADS} or preprocessing_command {SLURM_PREPROCESSING_COMMAND} can be provided. Not both."
457
+ )
458
+
341
459
  def to_dict(self) -> dict:
342
460
  "dump configuration to dict"
343
461
  return {
@@ -346,6 +464,7 @@ class SlurmConfig:
346
464
  SLURM_N_JOBS: self.n_jobs,
347
465
  SLURM_OTHER_OPTIONS: self.other_options,
348
466
  SLURM_PREPROCESSING_COMMAND: self.preprocessing_command,
467
+ SLURM_MODULES_TO_LOADS: self.modules_to_load,
349
468
  SLURM_CLEAN_SCRIPTS: self.clean_script,
350
469
  SLURM_NUMBER_OF_TASKS: self.n_tasks,
351
470
  SLURM_COR_PER_TASKS: self.n_cpu_per_task,
@@ -363,15 +482,11 @@ class SlurmConfig:
363
482
  n_tasks=config.get(SLURM_NUMBER_OF_TASKS, 1),
364
483
  n_cpu_per_task=config.get(SLURM_COR_PER_TASKS, 4),
365
484
  preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""),
485
+ modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")),
366
486
  clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0],
367
487
  )
368
488
 
369
489
 
370
- class StitchingType(_Enum):
371
- Z_PREPROC = "z-preproc"
372
- Z_POSTPROC = "z-postproc"
373
-
374
-
375
490
  def _cast_shift_to_str(shifts: Union[tuple, str, None]) -> str:
376
491
  if shifts is None:
377
492
  return ""
@@ -408,9 +523,9 @@ class StitchingConfiguration:
408
523
  axis_2_params: dict = None
409
524
  slurm_config: SlurmConfig = None
410
525
  flip_lr: Union[tuple, bool] = False
411
- "flip frame left-right. For scan this will happen after possible flip of NXtomo metadata x_flipped field (also know as lr_flipped)"
526
+ "flip frame left-right. For scan this will be append to the NXtransformations of the detector"
412
527
  flip_ud: Union[tuple, bool] = False
413
- "flip frame up-down. For scan this will happen after possible flip of NXtomo metadata y_flipped field (also know as ud_flipped)"
528
+ "flip frame up-down. For scan this will be append to the NXtransformations of the detector"
414
529
 
415
530
  overwrite_results: bool = False
416
531
  stitching_strategy: OverlapStitchingStrategy = OverlapStitchingStrategy.COSINUS_WEIGHTS
@@ -422,10 +537,16 @@ class StitchingConfiguration:
422
537
  rescale_frames: bool = False
423
538
  rescale_params: dict = None
424
539
 
540
+ normalization_by_sample: NormalizationBySample = None
541
+
425
542
  @property
426
543
  def stitching_type(self):
427
544
  raise NotImplementedError("Base class")
428
545
 
546
+ def __post_init__(self):
547
+ if self.normalization_by_sample is None:
548
+ self.normalization_by_sample = NormalizationBySample()
549
+
429
550
  @staticmethod
430
551
  def get_description_dict() -> dict:
431
552
  def get_pos_info(axis, unit, alternative):
@@ -526,6 +647,16 @@ class StitchingConfiguration:
526
647
  "help": f"advanced parameters for some stitching kernels. must be provided as 'key1=value1;key_2=value2'. Valid keys for now are: {KEY_THRESHOLD_FREQUENCY}: threshold to be used by the {OverlapStitchingStrategy.IMAGE_MINIMUM_DIVERGENCE.value} to split images low and high frequencies in Fourier space.",
527
648
  "type": "advanced",
528
649
  },
650
+ ALIGNMENT_AXIS_2_FIELD: {
651
+ "default": "center",
652
+ "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2.values()}",
653
+ "type": "advanced",
654
+ },
655
+ PAD_MODE_FIELD: {
656
+ "default": "constant",
657
+ "help": f"pad mode to use for frame alignment. Valid values are 'constant', 'edge', 'linear_ramp', maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap', and 'empty'. See nupy.pad documentation for details",
658
+ "type": "advanced",
659
+ },
529
660
  },
530
661
  OUTPUT_SECTION: {
531
662
  OVERWRITE_RESULTS_FIELD: {
@@ -583,6 +714,38 @@ class StitchingConfiguration:
583
714
  "help": "python virtual environment to use",
584
715
  "type": "optional",
585
716
  },
717
+ SLURM_MODULES_TO_LOADS: {
718
+ "default": "",
719
+ "help": "module to load",
720
+ "type": "optional",
721
+ },
722
+ },
723
+ NORMALIZATION_BY_SAMPLE_SECTION: {
724
+ NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: {
725
+ "default": False,
726
+ "help": "should we apply frame normalization by a sample or not",
727
+ "type": "advanced",
728
+ },
729
+ NORMALIZATION_BY_SAMPLE_METHOD: {
730
+ "default": "median",
731
+ "help": "method to compute the normalization value",
732
+ "type": "advanced",
733
+ },
734
+ NORMALIZATION_BY_SAMPLE_SIDE: {
735
+ "default": "left",
736
+ "help": "side to pick the sample",
737
+ "type": "advanced",
738
+ },
739
+ NORMALIZATION_BY_SAMPLE_MARGIN: {
740
+ "default": 0,
741
+ "help": "margin (in px) between border and sample",
742
+ "type": "advanced",
743
+ },
744
+ NORMALIZATION_BY_SAMPLE_WIDTH: {
745
+ "default": 30,
746
+ "help": "sample width (in px)",
747
+ "type": "advanced",
748
+ },
586
749
  },
587
750
  }
588
751
 
@@ -614,6 +777,7 @@ class StitchingConfiguration:
614
777
  self.overwrite_results,
615
778
  ),
616
779
  },
780
+ NORMALIZATION_BY_SAMPLE_SECTION: self.normalization_by_sample.to_dict(),
617
781
  }
618
782
 
619
783
 
@@ -627,6 +791,10 @@ class ZStitchingConfiguration(StitchingConfiguration):
627
791
  slice, tuple, None
628
792
  ] = None # slices to reconstruct. Over axis 0 for pre-processing, over axis 1 for post-processing. If None will reconstruct all
629
793
 
794
+ alignment_axis_2: AlignmentAxis2 = AlignmentAxis2.CENTER
795
+
796
+ pad_mode: str = "constant" # pad mode to be used for alignment
797
+
630
798
  def settle_inputs(self) -> None:
631
799
  self.settle_slices()
632
800
 
@@ -648,7 +816,11 @@ class ZStitchingConfiguration(StitchingConfiguration):
648
816
  {
649
817
  INPUTS_SECTION: {
650
818
  STITCHING_SLICES: slices,
651
- }
819
+ },
820
+ STITCHING_SECTION: {
821
+ ALIGNMENT_AXIS_2_FIELD: self.alignment_axis_2.value,
822
+ PAD_MODE_FIELD: self.pad_mode,
823
+ },
652
824
  },
653
825
  )
654
826
 
@@ -670,7 +842,7 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
670
842
  return StitchingType.Z_PREPROC
671
843
 
672
844
  def get_output_object(self):
673
- return HDF5TomoScan(
845
+ return NXtomoScan(
674
846
  scan=self.output_file_path,
675
847
  entry=self.output_data_path,
676
848
  )
@@ -850,6 +1022,11 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
850
1022
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
851
1023
  )
852
1024
  ),
1025
+ alignment_axis_2=AlignmentAxis2.from_value(
1026
+ config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1027
+ ),
1028
+ pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
1029
+ normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})),
853
1030
  )
854
1031
 
855
1032
 
@@ -862,6 +1039,7 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
862
1039
  input_volumes: tuple = () # tuple of VolumeBase
863
1040
  output_volume: Optional[VolumeIdentifier] = None
864
1041
  voxel_size: Optional[float] = None
1042
+ alignment_axis_1: AlignmentAxis1 = AlignmentAxis1.CENTER
865
1043
 
866
1044
  @property
867
1045
  def stitching_type(self) -> StitchingType:
@@ -909,16 +1087,22 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
909
1087
  if len(self.input_volumes) == 0:
910
1088
  raise ValueError("No input volume provided. Cannot settle slices")
911
1089
  if slices is None:
912
- slices = slice(0, self.input_volumes[0].get_volume_shape()[1], 1)
1090
+ # before alignment was existing
1091
+ # slices = slice(0, self.input_volumes[0].get_volume_shape()[1], 1)
1092
+ slices = slice(
1093
+ 0,
1094
+ max([volume.get_volume_shape()[1] for volume in self.input_volumes]),
1095
+ 1,
1096
+ )
913
1097
  n_slices = slices.stop
914
1098
  if isinstance(slices, slice):
915
1099
  # force slices indices to be positive
916
1100
  start = slices.start
917
1101
  if start < 0:
918
- start += self.input_volumes[0].get_volume_shape()[1] + 1
1102
+ start += max([volume.get_volume_shape()[1] for volume in self.input_volumes]) + 1
919
1103
  stop = slices.stop
920
1104
  if stop < 0:
921
- stop += self.input_volumes[0].get_volume_shape()[1] + 1
1105
+ stop += max([volume.get_volume_shape()[1] for volume in self.input_volumes]) + 1
922
1106
  step = slices.step
923
1107
  if step is None:
924
1108
  step = 1
@@ -987,6 +1171,14 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
987
1171
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
988
1172
  )
989
1173
  ),
1174
+ alignment_axis_1=AlignmentAxis1.from_value(
1175
+ config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
1176
+ ),
1177
+ alignment_axis_2=AlignmentAxis2.from_value(
1178
+ config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1179
+ ),
1180
+ pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
1181
+ normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})),
990
1182
  )
991
1183
 
992
1184
  def to_dict(self):
@@ -1007,6 +1199,9 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1007
1199
  if self.output_volume is not None
1008
1200
  else "",
1009
1201
  },
1202
+ STITCHING_SECTION: {
1203
+ ALIGNMENT_AXIS_1_FIELD: self.alignment_axis_1.value,
1204
+ },
1010
1205
  },
1011
1206
  )
1012
1207
 
@@ -1022,6 +1217,13 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1022
1217
  "type": "required",
1023
1218
  },
1024
1219
  },
1220
+ STITCHING_SECTION: {
1221
+ ALIGNMENT_AXIS_1_FIELD: {
1222
+ "default": "center",
1223
+ "help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1.values()}",
1224
+ "type": "advanced",
1225
+ }
1226
+ },
1025
1227
  },
1026
1228
  )
1027
1229
 
@@ -0,0 +1,6 @@
1
+ from silx.utils.enum import Enum as _Enum
2
+
3
+
4
+ class StitchingType(_Enum):
5
+ Z_PREPROC = "z-preproc"
6
+ Z_POSTPROC = "z-postproc"
@@ -34,8 +34,10 @@ class ZFrameComposition(_FrameCompositionBase):
34
34
  )
35
35
 
36
36
  def compose(self, output_frame: numpy.ndarray, input_frames: tuple):
37
- if not output_frame.ndim == 2:
38
- raise TypeError(f"output_frame is expected to be 2D and not {output_frame.ndim}")
37
+ if not output_frame.ndim in (2, 3):
38
+ raise TypeError(
39
+ f"output_frame is expected to be 2D (gray scale) or 3D (RGB(A)) and not {output_frame.ndim}"
40
+ )
39
41
  for (
40
42
  global_start_y,
41
43
  global_end_y,
nabu/stitching/overlap.py CHANGED
@@ -29,11 +29,15 @@ __date__ = "10/05/2022"
29
29
 
30
30
 
31
31
  import numpy
32
- from typing import Optional
32
+ import logging
33
+ from typing import Optional, Union
33
34
  from silx.utils.enum import Enum as _Enum
34
35
  from nabu.misc import fourier_filters
35
36
  from scipy.fft import rfftn as local_fftn
36
37
  from scipy.fft import irfftn as local_ifftn
38
+ from tomoscan.utils.geometry import BoundingBox1D
39
+
40
+ _logger = logging.getLogger(__name__)
37
41
 
38
42
 
39
43
  class OverlapStitchingStrategy(_Enum):
@@ -211,7 +215,7 @@ def compute_image_minimum_divergence(img_1: numpy.ndarray, img_2: numpy.ndarray,
211
215
  It split the two images into two parts: high frequency and low frequency.
212
216
 
213
217
  The two low frequency part will be stitched using a 'sinusoidal' / cosinus weights approach.
214
- When the two high frequency part will be stitched by taking the lower divergent pixels
218
+ When the two high frequency parts will be stitched by taking the lower divergent pixels
215
219
  """
216
220
 
217
221
  # split low and high frequencies
@@ -268,10 +272,102 @@ def compute_image_minimum_divergence(img_1: numpy.ndarray, img_2: numpy.ndarray,
268
272
 
269
273
def compute_image_higher_signal(img_1: numpy.ndarray, img_2: numpy.ndarray):
    """
    Build an image by keeping, pixel by pixel, the value coming from the image having the higher signal.

    When both images hold the same value the pixel is taken from `img_1`.
    A use case: if some artefacts create stripes on the images (from scintillator artefacts for example)
    this method could remove them.

    :param img_1: first image
    :param img_2: second image
    :return: array filled with `img_1` where `img_1 >= img_2` and with `img_2` elsewhere
    """
    # note: to be thought about. Maybe it can be interesting to rescale img_1 and img_2
    # to get something more coherent
    first_is_higher = img_1 >= img_2
    return numpy.where(first_is_higher, img_1, img_2)
282
+
283
+
284
def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis: int, raise_error: bool):
    """
    Check that each frame overlaps its direct neighbours (and only them) along `axis`.

    Each frame is reduced to a 1D bounding box along `axis`: the frame position is used as the
    center of the box and the frame shape along `axis` as its length. Then:

    * two juxtaposed (consecutive) frames are expected to be sorted and to overlap
    * two frames which are not juxtaposed are expected not to overlap

    :param frames: list of ordered / sorted frames along axis to test (from higher to lower)
    :param positions: positions of frames in 3D space as (position axis 0, position axis 1, position axis 2)
    :param axis: axis to check
    :param raise_error: if True then raise an error as soon as a check fails. Else log each failure as an error
    :raises TypeError: if `frames` is not a tuple or a numpy array, or if `positions` is not a tuple
    :raises ValueError: if a check fails and `raise_error` is True
    """
    if not isinstance(frames, (tuple, numpy.ndarray)):
        raise TypeError(f"frames is expected to be a tuple or a numpy array. Get {type(frames)} instead")
    if not isinstance(positions, tuple):
        raise TypeError(f"positions is expected to be a tuple. Get {type(positions)} instead")
    assert isinstance(axis, int), "axis is expected to be an int"
    assert isinstance(raise_error, bool), "raise_error is expected to be a bool"

    def treat_error(error_msg: str):
        """Raise or log `error_msg` according to `raise_error`."""
        if raise_error:
            raise ValueError(error_msg)
        else:
            # log the message itself (and not the 'raise_error' flag)
            _logger.error(error_msg)

    # convert each frame to appropriate bounding box according to the axis
    def convert_to_bb(frame: numpy.ndarray, position: tuple, axis: int):
        assert isinstance(axis, int)
        assert isinstance(position, tuple), f"position expected a tuple. Get {type(position)} instead"
        # the position is considered to be the center of the frame along 'axis'
        start_frame = position[axis] - frame.shape[axis] // 2
        end_frame = start_frame + frame.shape[axis]
        return BoundingBox1D(start_frame, end_frame)

    # bounding box -> frame position (insertion order follows the order of 'frames')
    # NOTE(review): boxes are used as dict keys; frames leading to equal boxes would presumably
    # collapse into a single entry - confirm against BoundingBox1D hashing
    bounding_boxes = {
        convert_to_bb(frame=frame, position=position, axis=axis): position for frame, position in zip(frames, positions)
    }
    all_bounding_boxes = tuple(bounding_boxes.keys())

    def get_frame_index(my_bb) -> str:
        """Return the 1-based rank of `my_bb` as an english ordinal ('1st', '2nd', '3rd', '4th', '11th'...)."""
        bb_index = all_bounding_boxes.index(my_bb) + 1
        if 11 <= bb_index % 100 <= 13:
            # 11, 12 and 13 are irregular: they always take 'th'
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(bb_index % 10, "th")
        return f"{bb_index}{suffix}"

    # check that there is an overlap between two juxtaposed bb (or frame at the end)
    for bb_frame, bb_next_frame in zip(all_bounding_boxes[:-1], all_bounding_boxes[1:]):
        if bb_frame.max < bb_next_frame.min:
            treat_error("provided frames seems un sorted (from the higher to the lower)")
        # NOTE(review): reported as a 'full overlap' by the original message - confirm the intended condition
        if bb_frame.min < bb_next_frame.min:
            treat_error(
                f"Seems like {get_frame_index(bb_frame)} frame is fully overlaping with frame {get_frame_index(bb_next_frame)}"
            )
        if bb_frame.get_overlap(bb_next_frame) is None:
            treat_error(
                f"no overlap found between two juxtaposed frames - {get_frame_index(bb_frame)} and {get_frame_index(bb_next_frame)}"
            )

    # check there is no overlap between none juxtaposed bb:
    # each box is tested against all boxes except itself and its direct neighbours
    for index, bb_frame in enumerate(all_bounding_boxes):
        for other_index, bb_not_juxtaposed_frame in enumerate(all_bounding_boxes):
            if abs(other_index - index) <= 1:
                continue
            if bb_frame.get_overlap(bb_not_juxtaposed_frame) is not None:
                treat_error(
                    f"overlap found between two frames not juxtaposed - {bounding_boxes[bb_frame]} and {bounding_boxes[bb_not_juxtaposed_frame]}"
                )
@@ -0,0 +1,60 @@
1
+ import numpy
2
+ from silx.utils.enum import Enum as _Enum
3
+
4
+
5
class SampleSide(_Enum):
    """Side of the frame from which the normalization sample is picked."""

    # sample taken from the first columns of the frame
    LEFT = "left"

    # sample taken from the last columns of the frame
    RIGHT = "right"
8
+
9
+
10
class Method(_Enum):
    """Statistic computed on the sample to obtain the normalization value."""

    # use the mean of the sample
    MEAN = "mean"

    # use the median of the sample
    MEDIAN = "median"
13
+
14
+
15
def normalize_frame(
    frame: numpy.ndarray, side: SampleSide, method: Method, sample_width: int = 50, margin_before_sample: int = 0
):
    """
    Normalize the frame (in place) from a sample section picked at the left or at the right of the frame.

    For each line of the frame the mean (or the median) of the sample is computed and subtracted
    from the whole line.

    :param frame: frame to normalize (2D numpy array). It is modified in place
    :param SampleSide side: side to pick the sample
    :param Method method: normalization method
    :param int sample_width: sample width (number of columns)
    :param int margin_before_sample: number of columns skipped between the border of the frame and the sampling area
    :return: the normalized frame (same object as `frame`)
    :raises TypeError: if `frame` is not a 2D numpy array
    :raises ValueError: if `sample_width` is not strictly positive, if `margin_before_sample` is negative
                        or if the frame is too narrow to contain the margin and the sample
    """
    if not isinstance(frame, numpy.ndarray):
        raise TypeError("Frame is expected to be a 2D numpy array.")
    if frame.ndim != 2:
        raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
    side = SampleSide.from_value(side)
    method = Method.from_value(method)

    # an empty sample would silently produce NaN and a negative margin would pick a wrong area
    if sample_width <= 0:
        raise ValueError(f"sample_width is expected to be strictly positive. Get {sample_width}")
    if margin_before_sample < 0:
        raise ValueError(f"margin_before_sample is expected to be positive. Get {margin_before_sample}")
    if frame.shape[1] < sample_width + margin_before_sample:
        raise ValueError(
            f"frame width ({frame.shape[1]}) < sample_width + margin ({sample_width + margin_before_sample})"
        )

    # create sample
    if side is SampleSide.LEFT:
        sample_start = margin_before_sample
        sample_end = margin_before_sample + sample_width
    elif side is SampleSide.RIGHT:
        # the margin is counted from the right border of the frame
        sample_start = frame.shape[1] - (sample_width + margin_before_sample)
        sample_end = frame.shape[1] - margin_before_sample
    else:
        raise ValueError(f"side {side.value} not handled")
    sample = frame[:, sample_start:sample_end]

    # do normalization: one value per line
    if method is Method.MEAN:
        normalization_array = numpy.mean(sample, axis=1)
    elif method is Method.MEDIAN:
        normalization_array = numpy.median(sample, axis=1)
    else:
        raise ValueError(f"method {method.value} not handled")

    # subtract to each line its own normalization value (broadcast over the columns)
    frame -= normalization_array[:, numpy.newaxis]
    return frame
+ return frame