nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nabu/__init__.py +1 -1
  2. nabu/app/cast_volume.py +12 -1
  3. nabu/app/cli_configs.py +80 -3
  4. nabu/app/estimate_motion.py +54 -0
  5. nabu/app/multicor.py +2 -4
  6. nabu/app/pcaflats.py +116 -0
  7. nabu/app/reconstruct.py +1 -7
  8. nabu/app/reduce_dark_flat.py +5 -2
  9. nabu/estimation/cor.py +1 -1
  10. nabu/estimation/motion.py +557 -0
  11. nabu/estimation/tests/test_motion_estimation.py +471 -0
  12. nabu/estimation/tilt.py +1 -1
  13. nabu/estimation/translation.py +47 -1
  14. nabu/io/cast_volume.py +94 -13
  15. nabu/io/reader.py +32 -1
  16. nabu/io/tests/test_remove_volume.py +152 -0
  17. nabu/pipeline/config_validators.py +42 -43
  18. nabu/pipeline/estimators.py +255 -0
  19. nabu/pipeline/fullfield/chunked.py +67 -43
  20. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  21. nabu/pipeline/fullfield/nabu_config.py +17 -11
  22. nabu/pipeline/fullfield/processconfig.py +8 -2
  23. nabu/pipeline/fullfield/reconstruction.py +3 -0
  24. nabu/pipeline/params.py +12 -0
  25. nabu/pipeline/tests/test_estimators.py +240 -3
  26. nabu/preproc/ccd.py +53 -3
  27. nabu/preproc/flatfield.py +306 -1
  28. nabu/preproc/shift.py +3 -1
  29. nabu/preproc/tests/test_pcaflats.py +154 -0
  30. nabu/processing/rotation_cuda.py +3 -1
  31. nabu/processing/tests/test_rotation.py +4 -2
  32. nabu/reconstruction/fbp.py +7 -0
  33. nabu/reconstruction/fbp_base.py +31 -7
  34. nabu/reconstruction/fbp_opencl.py +8 -0
  35. nabu/reconstruction/filtering_opencl.py +2 -0
  36. nabu/reconstruction/mlem.py +47 -13
  37. nabu/reconstruction/tests/test_filtering.py +13 -2
  38. nabu/reconstruction/tests/test_mlem.py +91 -62
  39. nabu/resources/dataset_analyzer.py +144 -20
  40. nabu/resources/nxflatfield.py +101 -35
  41. nabu/resources/tests/test_nxflatfield.py +1 -1
  42. nabu/resources/utils.py +16 -10
  43. nabu/stitching/alignment.py +7 -7
  44. nabu/stitching/config.py +22 -20
  45. nabu/stitching/definitions.py +2 -2
  46. nabu/stitching/overlap.py +4 -4
  47. nabu/stitching/sample_normalization.py +5 -5
  48. nabu/stitching/stitcher/post_processing.py +5 -3
  49. nabu/stitching/stitcher/pre_processing.py +24 -20
  50. nabu/stitching/tests/test_config.py +3 -3
  51. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  52. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  53. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  54. nabu/stitching/utils/utils.py +7 -7
  55. nabu/testutils.py +1 -4
  56. nabu/utils.py +13 -0
  57. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
  58. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
  61. nabu/app/correct_rot.py +0 -62
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  63. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
nabu/resources/nxflatfield.py CHANGED
@@ -6,6 +6,8 @@ from silx.io import get_data
  from tomoscan.framereducer.reducedframesinfos import ReducedFramesInfos
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
  from ..utils import check_supported, is_writeable
+ from ..preproc.flatfield import PCAFlatsDecomposer
+ from ..io.reader import NXDarksFlats


  def get_frame_possible_urls(dataset_info, user_dir, output_dir):
@@ -22,7 +24,7 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
  Output processing directory
  """

- frame_types = ["flats", "darks"]
+ frame_types = ["flats", "darks", "pcaflats"]
  h5scan = dataset_info.dataset_scanner # tomoscan object

  def make_dataurl(dirname, frame_type):
@@ -32,8 +34,10 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):

  if frame_type == "flats":
  dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0]
- else:
+ elif frame_type == "darks":
  dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0]
+ elif frame_type == "pcaflats":
+ dataurl_default_template = h5scan.PCA_FLATS_DATAURLS[0]

  rel_file_path = dataurl_default_template.file_path().format(scan_prefix=h5scan.get_dataset_basename())
  return DataUrl(
@@ -149,8 +153,94 @@ def data_url_exists(data_url):
  return group_exists


+ def _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls):
+ if flatfield_mode == "pca":
+ dfreader = NXDarksFlats(dataset_info.location)
+ darks = np.concatenate([d for d in dfreader.get_raw_darks()], axis=0)
+ flats = np.concatenate([f for f in dfreader.get_raw_flats()], axis=0)
+ pcaflats_darks = PCAFlatsDecomposer(flats, darks)
+
+ # Get "where to write". tomoscan expects a DataUrl
+ pcaflats_dir_url = reduced_frames_urls.get("user", None)
+ if pcaflats_dir_url is not None:
+ output_url = pcaflats_dir_url
+ elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())):
+ output_url = reduced_frames_urls["dataset"]
+ else:
+ output_url = reduced_frames_urls["output"]
+ pcaflats_darks.save_decomposition(
+ path=output_url["pcaflats"].file_path(), entry=output_url["pcaflats"].data_path().strip("/").split("/")[0]
+ )
+ dataset_info.logger.info("PCA flats computed and written at %s" % (output_url["pcaflats"].file_path()))
+
+ # Update dataset_info with pca flats and dark
+ dataset_info.darks = {0: pcaflats_darks.dark}
+ flats = {0: pcaflats_darks.mean}
+ for k in range(len(pcaflats_darks.components)):
+ flats.update({k + 1: pcaflats_darks.components[k]})
+ dataset_info.flats = flats
+ else:
+ try:
+ dataset_info.flats = dataset_info.get_reduced_flats()
+ dataset_info.darks = dataset_info.get_reduced_darks()
+ except FileNotFoundError:
+ msg = "Could not find any flats and/or darks"
+ raise FileNotFoundError(msg)
+ _, flats_info, darks_info = save_reduced_frames(
+ dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
+ )
+ dataset_info.flats_srcurrent = flats_info.machine_electric_current
+
+
+ def _load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from):
+ if "pcaflats" not in frames_types:
+ reduced_frames_with_info = {}
+ for frame_type in frames_types:
+ reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
+ dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
+ )
+ dataset_info.logger.info(
+ "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
+ )
+ red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
+ setattr(
+ dataset_info,
+ frame_type,
+ {k: get_data(red_frames_dict[k]) for k in red_frames_dict},
+ )
+ if frame_type == "flats":
+ dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
+ else:
+ df_path = reduced_frames_urls[where_to_load_from]["pcaflats"].file_path()
+ entry = reduced_frames_urls[where_to_load_from]["pcaflats"].data_path()
+
+ # Update dark
+ dark_url = DataUrl(f"silx://{df_path}?{entry}/dark")
+ dark = get_data(dark_url)
+ setattr(
+ dataset_info,
+ "dark",
+ {0: dark},
+ )
+ # Update flats with principal compenents
+ # Take mean as first comp., mask as second, flats thereafter
+ flats_url = DataUrl(f"silx://{df_path}?{entry}/p_components")
+ mean_url = DataUrl(f"silx://{df_path}?{entry}/p_mean")
+ flats = get_data(flats_url)
+ mean = get_data(mean_url)
+ flats = np.concatenate([mean[np.newaxis], flats], axis=0)
+ setattr(
+ dataset_info,
+ "flats",
+ {k: flats[k] for k in range(len(flats))},
+ )
+ dataset_info.logger.info("Loaded %s from %s" % ("PCA darks/flats", df_path))
+
+
  # pylint: disable=E1136
- def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=None, darks_flats_dir=None):
+ def update_dataset_info_flats_darks(
+ dataset_info, flatfield_mode, loading_mode="load_if_present", output_dir=None, darks_flats_dir=None
+ ):
  """
  Update a DatasetAnalyzer object with reduced flats/darks (hereafter "reduced frames").

@@ -170,23 +260,14 @@ update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=Non
  if flatfield_mode is False:
  return

- frames_types = ["darks", "flats"]
+ if flatfield_mode == "pca":
+ frames_types = ["pcaflats"]
+ else:
+ frames_types = ["darks", "flats"]
  reduced_frames_urls = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir)

- def _compute_and_save_reduced_frames():
- try:
- dataset_info.flats = dataset_info.get_reduced_flats()
- dataset_info.darks = dataset_info.get_reduced_darks()
- except FileNotFoundError:
- msg = "Could not find any flats and/or darks"
- raise FileNotFoundError(msg)
- _, flats_info, darks_info = save_reduced_frames(
- dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
- )
- dataset_info.flats_srcurrent = flats_info.machine_electric_current
-
- if flatfield_mode == "force-compute":
- _compute_and_save_reduced_frames()
+ if loading_mode == "force-compute":
+ _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
  return

  def _can_load_from(folder_type):
@@ -206,21 +287,6 @@ update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=Non
  raise ValueError("Could not load darks/flats (using 'force-load')")

  if where_to_load_from is not None:
- reduced_frames_with_info = {}
- for frame_type in frames_types:
- reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
- dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
- )
- dataset_info.logger.info(
- "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
- )
- red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
- setattr(
- dataset_info,
- frame_type,
- {k: get_data(red_frames_dict[k]) for k in red_frames_dict},
- )
- if frame_type == "flats":
- dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
+ _load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from)
  else:
- _compute_and_save_reduced_frames()
+ _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
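A minimal sketch (plain numpy, hypothetical array shapes) of the flats/darks dictionaries that the "pca" branch above builds on dataset_info: index 0 holds the dark (respectively the mean flat), and indices 1..N hold the principal components.

    import numpy as np

    # Hypothetical PCA decomposition results, mirroring the decomposer attributes used above:
    # one dark frame, the mean flat, and a stack of principal components.
    dark = np.zeros((64, 64), dtype=np.float32)
    mean_flat = np.ones((64, 64), dtype=np.float32)
    components = np.random.default_rng(0).normal(size=(3, 64, 64)).astype(np.float32)

    # dataset_info.darks: the single (PCA) dark at index 0
    darks = {0: dark}
    # dataset_info.flats: mean flat at index 0, then one entry per principal component
    flats = {0: mean_flat}
    for k in range(len(components)):
        flats.update({k + 1: components[k]})

    assert sorted(flats) == [0, 1, 2, 3]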
nabu/resources/tests/test_nxflatfield.py CHANGED
@@ -80,7 +80,7 @@ class TestNXFlatField:
  output_dir = self.params.get("output_dir", None)
  if output_dir is not None:
  output_dir = output_dir.format(tempdir=self.tempdir)
- update_dataset_info_flats_darks(dataset_info, True, output_dir=output_dir)
+ update_dataset_info_flats_darks(dataset_info, True, loading_mode="load_if_present", output_dir=output_dir)
  # After reduction (median/mean), the flats/darks are located in another file.
  # median(series_1) goes to entry/flats/idx1, mean(series_2) goes to entry/flats/idx2, etc.
  assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"]) # noqa: C401
nabu/resources/utils.py CHANGED
@@ -1,8 +1,9 @@
  from ast import literal_eval
  import numpy as np
+ import pint
  from psutil import virtual_memory, cpu_count
- from pyunitsystem.metricsystem import MetricSystem
- from pyunitsystem.energysystem import EnergySI
+
+ _ureg = pint.get_application_registry()


  def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False):
@@ -163,12 +164,17 @@ def get_quantities_and_units(string, sep=";"):
  value, unit = value_and_unit.split()
  val = float(value)
  # Convert to SI
- try:
- # handle metrics
- conversion_factor = MetricSystem.from_str(unit).value
- except ValueError:
- # handle energies
- conversion_factor = EnergySI.from_str(unit).value / EnergySI.KILOELECTRONVOLT.value
-
- result[quantity_name] = val * conversion_factor
+ if unit.lower() == "kev":
+ current_unit = _ureg.keV
+ elif unit.lower() == "ev":
+ current_unit = _ureg.eV
+ else:
+ current_unit = _ureg(unit)
+ # handle energies (to move to keV)
+ if _ureg.keV.dimensionality == current_unit.dimensionality:
+ result[quantity_name] = (val * current_unit).to(_ureg.keV).magnitude
+ elif _ureg.meter.dimensionality == current_unit.dimensionality:
+ result[quantity_name] = (val * current_unit).to_base_units().magnitude
+ else:
+ raise ValueError(f"Cannot convert: {unit}")
  return result
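The pint-based replacement keeps the previous convention: energies come back in keV and lengths in SI base units (metres). A standalone sketch of that conversion path, mirroring the branch logic above (the helper name _convert is illustrative, not part of the package):

    import pint

    _ureg = pint.get_application_registry()

    def _convert(val, unit):
        # same dispatch as above: "kev"/"ev" handled explicitly, anything else parsed by pint
        if unit.lower() == "kev":
            current_unit = _ureg.keV
        elif unit.lower() == "ev":
            current_unit = _ureg.eV
        else:
            current_unit = _ureg(unit)
        if _ureg.keV.dimensionality == current_unit.dimensionality:
            return (val * current_unit).to(_ureg.keV).magnitude
        elif _ureg.meter.dimensionality == current_unit.dimensionality:
            return (val * current_unit).to_base_units().magnitude
        raise ValueError(f"Cannot convert: {unit}")

    print(_convert(35000.0, "ev"))   # 35.0 (keV)
    print(_convert(1.2, "mm"))       # 0.0012 (m)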
nabu/stitching/alignment.py CHANGED
@@ -1,10 +1,10 @@
+ from enum import Enum
  import h5py
  import numpy
  from typing import Union
- from silx.utils.enum import Enum as _Enum


- class AlignmentAxis2(_Enum):
+ class AlignmentAxis2(Enum):
  """Specific alignment named to help users orienting themself with specific name"""

  CENTER = "center"
@@ -12,7 +12,7 @@ class AlignmentAxis2(_Enum):
  RIGTH = "right"


- class AlignmentAxis1(_Enum):
+ class AlignmentAxis1(Enum):
  """Specific alignment named to help users orienting themself with specific name"""

  FRONT = "front"
@@ -20,7 +20,7 @@ class AlignmentAxis1(_Enum):
  BACK = "back"


- class _Alignment(_Enum):
+ class _Alignment(Enum):
  """Internal alignment to be used for 2D alignment"""

  LOWER_BOUNDARY = "lower boundary"
@@ -29,7 +29,7 @@ class _Alignment(_Enum):

  @classmethod
  def from_value(cls, value):
- # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
+ # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition.
  if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
  return _Alignment.LOWER_BOUNDARY
  elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
@@ -37,7 +37,7 @@ class _Alignment(_Enum):
  elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
  return _Alignment.CENTER
  else:
- return super().from_value(value)
+ return super().__new__(cls, value)


  def align_frame(
@@ -103,7 +103,7 @@ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width
  :param HAlignment alignment: alignment strategy
  :param int new_width: output data width
  """
- alignment = AlignmentAxis2.from_value(alignment).value
+ alignment = AlignmentAxis2(alignment).value
  return align_frame(
  data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
  )
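With the standard-library Enum, silx's from_value() classmethod is replaced by the plain constructor: EnumClass(value) looks a member up by its value, and passing an existing member returns it unchanged. A self-contained sketch (the class below is a stand-in mirroring the declaration above):

    from enum import Enum

    class AlignmentAxis2(Enum):
        CENTER = "center"
        LEFT = "left"
        RIGTH = "right"

    assert AlignmentAxis2("center") is AlignmentAxis2.CENTER            # lookup by value
    assert AlignmentAxis2(AlignmentAxis2.LEFT) is AlignmentAxis2.LEFT   # members pass through
    print([a.value for a in AlignmentAxis2])                            # replaces silx's .values()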
nabu/stitching/config.py CHANGED
@@ -1,9 +1,9 @@
+ import pint
  from math import ceil
  from typing import Optional, Union
  from collections.abc import Iterable, Sized
  from dataclasses import dataclass
  import numpy
- from pyunitsystem.metricsystem import MetricSystem
  from nxtomo.paths import nxtomo
  from tomoscan.factory import Factory
  from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
@@ -18,6 +18,8 @@ from .utils.utils import ShiftAlgorithm
  from .definitions import StitchingType
  from .alignment import AlignmentAxis1, AlignmentAxis2

+ _ureg = pint.get_application_registry()
+
  # ruff: noqa: S105

  KEY_IMG_REG_METHOD = "img_reg_method"
@@ -220,7 +222,7 @@ def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
  if my_str == "":
  return None
  try:
- shift = ShiftAlgorithm.from_value(my_str)
+ shift = ShiftAlgorithm(my_str)
  except ValueError:
  shifts_as_str = filter(None, my_str.replace(";", ",").split(","))
  return [float(shift) for shift in shifts_as_str]
@@ -336,7 +338,7 @@ class NormalizationBySample:

  @method.setter
  def method(self, method: Union[Method, str]) -> None:
- self._method = Method.from_value(method)
+ self._method = Method(method)

  @property
  def margin(self) -> int:
@@ -353,7 +355,7 @@ class NormalizationBySample:

  @side.setter
  def side(self, side: Union[SampleSide, str]):
- self._side = SampleSide.from_value(side)
+ self._side = SampleSide(side)

  @property
  def width(self) -> int:
@@ -548,12 +550,12 @@ class StitchingConfiguration:
  STITCHING_SECTION: {
  STITCHING_TYPE_FIELD: {
  "default": StitchingType.Z_PREPROC.value,
- "help": f"stitching to be applied. Must be in {StitchingType.values()}",
+ "help": f"stitching to be applied. Must be in {[st.value for st in StitchingType]}",
  "type": "required",
  },
  STITCHING_STRATEGY_FIELD: {
  "default": "cosinus weights",
- "help": f"Policy to apply to compute the overlap area. Must be in {OverlapStitchingStrategy.values()}.",
+ "help": f"Policy to apply to compute the overlap area. Must be in {[ov.value for ov in OverlapStitchingStrategy]}.",
  "type": "required",
  },
  CROSS_CORRELATION_SLICE_FIELD: {
@@ -633,7 +635,7 @@ class StitchingConfiguration:
  },
  ALIGNMENT_AXIS_2_FIELD: {
  "default": "center",
- "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2.values()}",
+ "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {[aa.value for aa in AlignmentAxis2]}",
  "type": "advanced",
  },
  PAD_MODE_FIELD: {
@@ -755,7 +757,7 @@ class StitchingConfiguration:
  AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px),
  AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm),
  AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}),
- STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy.from_value(self.stitching_strategy).value,
+ STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy(self.stitching_strategy).value,
  FLIP_UD: self.flip_ud,
  FLIP_LR: self.flip_lr,
  RESCALE_FRAMES: self.rescale_frames,
@@ -934,7 +936,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
  if self.pixel_size is None:
  pixel_size_mm = ""
  else:
- pixel_size_mm = self.pixel_size * MetricSystem.MILLIMETER.value
+ pixel_size_mm = (self.pixel_size * _ureg.meter).to(_ureg.millimeter).magnitude
  return concatenate_dict(
  super().to_dict(),
  {
@@ -998,10 +1000,10 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
  if pixel_size == "":
  pixel_size = None
  else:
- pixel_size = float(pixel_size) / MetricSystem.MM
+ pixel_size = (float(pixel_size) * _ureg.millimeter).to_base_units().magnitude

  return cls(
- stitching_strategy=OverlapStitchingStrategy.from_value(
+ stitching_strategy=OverlapStitchingStrategy(
  config[STITCHING_SECTION].get(
  STITCHING_STRATEGY_FIELD,
  OverlapStitchingStrategy.COSINUS_WEIGHTS,
@@ -1042,7 +1044,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
  )
  ),
- alignment_axis_2=AlignmentAxis2.from_value(
+ alignment_axis_2=AlignmentAxis2(
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
  ),
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
@@ -1163,11 +1165,11 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
  if voxel_size == "":
  voxel_size = None
  else:
- voxel_size = float(voxel_size) * MetricSystem.MM
+ voxel_size = (float(voxel_size) * _ureg.millimeter).to_base_units().magnitude

  # on the next section the one with a default value qre the optional one
  return cls(
- stitching_strategy=OverlapStitchingStrategy.from_value(
+ stitching_strategy=OverlapStitchingStrategy(
  config[STITCHING_SECTION].get(
  STITCHING_STRATEGY_FIELD,
  OverlapStitchingStrategy.COSINUS_WEIGHTS,
@@ -1198,10 +1200,10 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
  )
  ),
- alignment_axis_1=AlignmentAxis1.from_value(
+ alignment_axis_1=AlignmentAxis1(
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
  ),
- alignment_axis_2=AlignmentAxis2.from_value(
+ alignment_axis_2=AlignmentAxis2(
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
  ),
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
@@ -1215,7 +1217,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
  if self.voxel_size is None:
  voxel_size_mm = ""
  else:
- voxel_size_mm = numpy.array(self.voxel_size) / MetricSystem.MM
+ voxel_size_mm = numpy.array((self.voxel_size * _ureg.meter).to(_ureg.millimeter).magnitude)

  return concatenate_dict(
  super().to_dict(),
@@ -1250,7 +1252,7 @@ class PostProcessedSingleAxisStitchingConfigura
  STITCHING_SECTION: {
  ALIGNMENT_AXIS_1_FIELD: {
  "default": "center",
- "help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1.values()}",
+ "help": f"alignment to apply over axis 1 if needed. Valid values are {[aa for aa in AlignmentAxis1]}",
  "type": "advanced",
  }
  },
@@ -1281,7 +1283,7 @@ def dict_to_config_obj(config: dict):
  if stitching_type is None:
  raise ValueError("Unable to find stitching type from config dict")
  else:
- stitching_type = StitchingType.from_value(stitching_type)
+ stitching_type = StitchingType(stitching_type)
  if stitching_type is StitchingType.Z_POSTPROC:
  return PostProcessedZStitchingConfiguration.from_dict(config)
  elif stitching_type is StitchingType.Z_PREPROC:
@@ -1304,7 +1306,7 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s
  if stitching_type is None:
  return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config)

- stitching_type = StitchingType.from_value(stitching_type)
+ stitching_type = StitchingType(stitching_type)
  if stitching_type is StitchingType.Z_POSTPROC:
  return z_postproc_stitching_config
  elif stitching_type is StitchingType.Z_PREPROC:
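The pixel/voxel size round-trips above keep the same convention as before: values are held in metres internally and serialized in millimetres. A hedged, standalone illustration of the two directions (the variable names are illustrative):

    import pint

    _ureg = pint.get_application_registry()

    pixel_size_mm_in_config = 6.5e-4  # value read from the configuration, in mm

    # from_dict direction: mm -> m (stored value)
    pixel_size_m = (pixel_size_mm_in_config * _ureg.millimeter).to_base_units().magnitude

    # to_dict direction: m -> mm (serialized value)
    pixel_size_back = (pixel_size_m * _ureg.meter).to(_ureg.millimeter).magnitude

    assert abs(pixel_size_back - pixel_size_mm_in_config) < 1e-12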
nabu/stitching/definitions.py CHANGED
@@ -1,7 +1,7 @@
- from silx.utils.enum import Enum as _Enum
+ from enum import Enum


- class StitchingType(_Enum):
+ class StitchingType(Enum):
  Y_PREPROC = "y-preproc"
  Z_PREPROC = "z-preproc"
  Z_POSTPROC = "z-postproc"
nabu/stitching/overlap.py CHANGED
@@ -1,7 +1,7 @@
  import numpy
  import logging
  from typing import Optional, Union
- from silx.utils.enum import Enum as _Enum
+ from enum import Enum
  from nabu.misc import fourier_filters
  from scipy.fft import rfftn as local_fftn
  from scipy.fft import irfftn as local_ifftn
@@ -10,7 +10,7 @@ from tomoscan.utils.geometry import BoundingBox1D
  _logger = logging.getLogger(__name__)


- class OverlapStitchingStrategy(_Enum):
+ class OverlapStitchingStrategy(Enum):
  MEAN = "mean"
  COSINUS_WEIGHTS = "cosinus weights"
  LINEAR_WEIGHTS = "linear weights"
@@ -72,7 +72,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
  self._stitching_axis = stitching_axis
  self._overlap_size = abs(overlap_size)
  self._frame_unstitched_axis_size = frame_unstitched_axis_size
- self._stitching_strategy = OverlapStitchingStrategy.from_value(stitching_strategy)
+ self._stitching_strategy = OverlapStitchingStrategy(stitching_strategy)
  self._weights_img_1 = None
  self._weights_img_2 = None
  if extra_params is None:
@@ -391,7 +391,7 @@ def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis:

  :return: (tested_bounding_box, bounding_boxes_to_test)
  """
- my_bounding_boxes = {bb_index: bb for bb_index, bb in enumerate(my_bounding_boxes)} # noqa: C416
+ my_bounding_boxes = {bb_index: bb for bb_index, bb in enumerate(my_bounding_boxes)}
  bounding_boxes = dict(
  filter(
  lambda pair: pair[0] not in (index - 1, index, index + 1),
nabu/stitching/sample_normalization.py CHANGED
@@ -1,13 +1,13 @@
+ from enum import Enum
  import numpy
- from silx.utils.enum import Enum as _Enum


- class SampleSide(_Enum):
+ class SampleSide(Enum):
  LEFT = "left"
  RIGHT = "right"


- class Method(_Enum):
+ class Method(Enum):
  MEAN = "mean"
  MEDIAN = "median"

@@ -28,8 +28,8 @@ def normalize_frame(
  raise TypeError(f"Frame is expected to be a 2D numpy array.")
  if frame.ndim != 2:
  raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
- side = SampleSide.from_value(side)
- method = Method.from_value(method)
+ side = SampleSide(side)
+ method = Method(method)

  if frame.shape[1] < sample_width + margin_before_sample:
  raise ValueError(
nabu/stitching/stitcher/post_processing.py CHANGED
@@ -2,6 +2,7 @@ import logging
  import numpy
  import os
  import h5py
+ import pint
  from typing import Union
  from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
  from nabu.stitching.alignment import AlignmentAxis1
@@ -15,7 +16,6 @@ from tomoscan.volumebase import VolumeBase
  from tomoscan.esrf.volume import HDF5Volume
  from collections.abc import Iterable
  from contextlib import AbstractContextManager
- from pyunitsystem.metricsystem import MetricSystem
  from nabu.stitching.config import (
  KEY_IMG_REG_METHOD,
  )
@@ -25,6 +25,8 @@ from .single_axis import SingleAxisStitcher

  _logger = logging.getLogger(__name__)

+ _ureg = pint.get_application_registry()
+

  class FlippingValueError(ValueError):
  pass
@@ -266,7 +268,7 @@ class PostProcessingStitching(SingleAxisStitcher):
  axis_N_pos_px = []
  for volume, pos_in_mm in zip(self.series, pos_as_mm):
  voxel_size_m = self.configuration.voxel_size or volume.voxel_size
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
+ axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / voxel_size_m[0])
  return axis_N_pos_px
  else:
  # deduce from motor position and pixel size
@@ -548,7 +550,7 @@ class _RawDatasetsContext(AbstractContextManager):
  return success

  def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
- alignment = AlignmentAxis1.from_value(alignment)
+ alignment = AlignmentAxis1(alignment)
  if alignment is AlignmentAxis1.BACK:
  axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
  elif alignment is AlignmentAxis1.CENTER: