nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/cast_volume.py +9 -1
- nabu/app/cli_configs.py +80 -3
- nabu/app/estimate_motion.py +54 -0
- nabu/app/multicor.py +2 -4
- nabu/app/pcaflats.py +116 -0
- nabu/app/reconstruct.py +1 -7
- nabu/app/reduce_dark_flat.py +5 -2
- nabu/estimation/cor.py +1 -1
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tilt.py +1 -1
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +100 -13
- nabu/io/reader.py +32 -1
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/pipeline/config_validators.py +42 -43
- nabu/pipeline/estimators.py +255 -0
- nabu/pipeline/fullfield/chunked.py +67 -43
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/nabu_config.py +20 -14
- nabu/pipeline/fullfield/processconfig.py +17 -3
- nabu/pipeline/fullfield/reconstruction.py +4 -1
- nabu/pipeline/params.py +12 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/flatfield.py +306 -1
- nabu/preproc/shift.py +3 -1
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +31 -7
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering_opencl.py +2 -0
- nabu/reconstruction/mlem.py +47 -13
- nabu/reconstruction/tests/test_filtering.py +13 -2
- nabu/reconstruction/tests/test_mlem.py +91 -62
- nabu/resources/dataset_analyzer.py +144 -20
- nabu/resources/nxflatfield.py +101 -35
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +7 -7
- nabu/stitching/config.py +22 -20
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/stitcher/post_processing.py +5 -3
- nabu/stitching/stitcher/pre_processing.py +24 -20
- nabu/stitching/tests/test_config.py +3 -3
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
- nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
- nabu/stitching/utils/utils.py +7 -7
- nabu/testutils.py +1 -4
- nabu/utils.py +13 -0
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/METADATA +3 -4
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/RECORD +64 -57
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/WHEEL +1 -1
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -62
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/licenses/LICENSE +0 -0
- {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/top_level.txt +0 -0
nabu/resources/nxflatfield.py
CHANGED
@@ -6,6 +6,8 @@ from silx.io import get_data
|
|
6
6
|
from tomoscan.framereducer.reducedframesinfos import ReducedFramesInfos
|
7
7
|
from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
|
8
8
|
from ..utils import check_supported, is_writeable
|
9
|
+
from ..preproc.flatfield import PCAFlatsDecomposer
|
10
|
+
from ..io.reader import NXDarksFlats
|
9
11
|
|
10
12
|
|
11
13
|
def get_frame_possible_urls(dataset_info, user_dir, output_dir):
|
@@ -22,7 +24,7 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
|
|
22
24
|
Output processing directory
|
23
25
|
"""
|
24
26
|
|
25
|
-
frame_types = ["flats", "darks"]
|
27
|
+
frame_types = ["flats", "darks", "pcaflats"]
|
26
28
|
h5scan = dataset_info.dataset_scanner # tomoscan object
|
27
29
|
|
28
30
|
def make_dataurl(dirname, frame_type):
|
@@ -32,8 +34,10 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
|
|
32
34
|
|
33
35
|
if frame_type == "flats":
|
34
36
|
dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0]
|
35
|
-
|
37
|
+
elif frame_type == "darks":
|
36
38
|
dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0]
|
39
|
+
elif frame_type == "pcaflats":
|
40
|
+
dataurl_default_template = h5scan.PCA_FLATS_DATAURLS[0]
|
37
41
|
|
38
42
|
rel_file_path = dataurl_default_template.file_path().format(scan_prefix=h5scan.get_dataset_basename())
|
39
43
|
return DataUrl(
|
@@ -149,8 +153,94 @@ def data_url_exists(data_url):
|
|
149
153
|
return group_exists
|
150
154
|
|
151
155
|
|
156
|
+
def _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls):
|
157
|
+
if flatfield_mode == "pca":
|
158
|
+
dfreader = NXDarksFlats(dataset_info.location)
|
159
|
+
darks = np.concatenate([d for d in dfreader.get_raw_darks()], axis=0)
|
160
|
+
flats = np.concatenate([f for f in dfreader.get_raw_flats()], axis=0)
|
161
|
+
pcaflats_darks = PCAFlatsDecomposer(flats, darks)
|
162
|
+
|
163
|
+
# Get "where to write". tomoscan expects a DataUrl
|
164
|
+
pcaflats_dir_url = reduced_frames_urls.get("user", None)
|
165
|
+
if pcaflats_dir_url is not None:
|
166
|
+
output_url = pcaflats_dir_url
|
167
|
+
elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())):
|
168
|
+
output_url = reduced_frames_urls["dataset"]
|
169
|
+
else:
|
170
|
+
output_url = reduced_frames_urls["output"]
|
171
|
+
pcaflats_darks.save_decomposition(
|
172
|
+
path=output_url["pcaflats"].file_path(), entry=output_url["pcaflats"].data_path().strip("/").split("/")[0]
|
173
|
+
)
|
174
|
+
dataset_info.logger.info("PCA flats computed and written at %s" % (output_url["pcaflats"].file_path()))
|
175
|
+
|
176
|
+
# Update dataset_info with pca flats and dark
|
177
|
+
dataset_info.darks = {0: pcaflats_darks.dark}
|
178
|
+
flats = {0: pcaflats_darks.mean}
|
179
|
+
for k in range(len(pcaflats_darks.components)):
|
180
|
+
flats.update({k + 1: pcaflats_darks.components[k]})
|
181
|
+
dataset_info.flats = flats
|
182
|
+
else:
|
183
|
+
try:
|
184
|
+
dataset_info.flats = dataset_info.get_reduced_flats()
|
185
|
+
dataset_info.darks = dataset_info.get_reduced_darks()
|
186
|
+
except FileNotFoundError:
|
187
|
+
msg = "Could not find any flats and/or darks"
|
188
|
+
raise FileNotFoundError(msg)
|
189
|
+
_, flats_info, darks_info = save_reduced_frames(
|
190
|
+
dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
|
191
|
+
)
|
192
|
+
dataset_info.flats_srcurrent = flats_info.machine_electric_current
|
193
|
+
|
194
|
+
|
195
|
+
def _load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from):
|
196
|
+
if "pcaflats" not in frames_types:
|
197
|
+
reduced_frames_with_info = {}
|
198
|
+
for frame_type in frames_types:
|
199
|
+
reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
|
200
|
+
dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
|
201
|
+
)
|
202
|
+
dataset_info.logger.info(
|
203
|
+
"Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
|
204
|
+
)
|
205
|
+
red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
|
206
|
+
setattr(
|
207
|
+
dataset_info,
|
208
|
+
frame_type,
|
209
|
+
{k: get_data(red_frames_dict[k]) for k in red_frames_dict},
|
210
|
+
)
|
211
|
+
if frame_type == "flats":
|
212
|
+
dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
|
213
|
+
else:
|
214
|
+
df_path = reduced_frames_urls[where_to_load_from]["pcaflats"].file_path()
|
215
|
+
entry = reduced_frames_urls[where_to_load_from]["pcaflats"].data_path()
|
216
|
+
|
217
|
+
# Update dark
|
218
|
+
dark_url = DataUrl(f"silx://{df_path}?{entry}/dark")
|
219
|
+
dark = get_data(dark_url)
|
220
|
+
setattr(
|
221
|
+
dataset_info,
|
222
|
+
"dark",
|
223
|
+
{0: dark},
|
224
|
+
)
|
225
|
+
# Update flats with principal compenents
|
226
|
+
# Take mean as first comp., mask as second, flats thereafter
|
227
|
+
flats_url = DataUrl(f"silx://{df_path}?{entry}/p_components")
|
228
|
+
mean_url = DataUrl(f"silx://{df_path}?{entry}/p_mean")
|
229
|
+
flats = get_data(flats_url)
|
230
|
+
mean = get_data(mean_url)
|
231
|
+
flats = np.concatenate([mean[np.newaxis], flats], axis=0)
|
232
|
+
setattr(
|
233
|
+
dataset_info,
|
234
|
+
"flats",
|
235
|
+
{k: flats[k] for k in range(len(flats))},
|
236
|
+
)
|
237
|
+
dataset_info.logger.info("Loaded %s from %s" % ("PCA darks/flats", df_path))
|
238
|
+
|
239
|
+
|
152
240
|
# pylint: disable=E1136
|
153
|
-
def update_dataset_info_flats_darks(
|
241
|
+
def update_dataset_info_flats_darks(
|
242
|
+
dataset_info, flatfield_mode, loading_mode="load_if_present", output_dir=None, darks_flats_dir=None
|
243
|
+
):
|
154
244
|
"""
|
155
245
|
Update a DatasetAnalyzer object with reduced flats/darks (hereafter "reduced frames").
|
156
246
|
|
@@ -170,23 +260,14 @@ def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=Non
|
|
170
260
|
if flatfield_mode is False:
|
171
261
|
return
|
172
262
|
|
173
|
-
|
263
|
+
if flatfield_mode == "pca":
|
264
|
+
frames_types = ["pcaflats"]
|
265
|
+
else:
|
266
|
+
frames_types = ["darks", "flats"]
|
174
267
|
reduced_frames_urls = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir)
|
175
268
|
|
176
|
-
|
177
|
-
|
178
|
-
dataset_info.flats = dataset_info.get_reduced_flats()
|
179
|
-
dataset_info.darks = dataset_info.get_reduced_darks()
|
180
|
-
except FileNotFoundError:
|
181
|
-
msg = "Could not find any flats and/or darks"
|
182
|
-
raise FileNotFoundError(msg)
|
183
|
-
_, flats_info, darks_info = save_reduced_frames(
|
184
|
-
dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
|
185
|
-
)
|
186
|
-
dataset_info.flats_srcurrent = flats_info.machine_electric_current
|
187
|
-
|
188
|
-
if flatfield_mode == "force-compute":
|
189
|
-
_compute_and_save_reduced_frames()
|
269
|
+
if loading_mode == "force-compute":
|
270
|
+
_compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
|
190
271
|
return
|
191
272
|
|
192
273
|
def _can_load_from(folder_type):
|
@@ -206,21 +287,6 @@ def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=Non
|
|
206
287
|
raise ValueError("Could not load darks/flats (using 'force-load')")
|
207
288
|
|
208
289
|
if where_to_load_from is not None:
|
209
|
-
|
210
|
-
for frame_type in frames_types:
|
211
|
-
reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
|
212
|
-
dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
|
213
|
-
)
|
214
|
-
dataset_info.logger.info(
|
215
|
-
"Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
|
216
|
-
)
|
217
|
-
red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
|
218
|
-
setattr(
|
219
|
-
dataset_info,
|
220
|
-
frame_type,
|
221
|
-
{k: get_data(red_frames_dict[k]) for k in red_frames_dict},
|
222
|
-
)
|
223
|
-
if frame_type == "flats":
|
224
|
-
dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
|
290
|
+
_load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from)
|
225
291
|
else:
|
226
|
-
_compute_and_save_reduced_frames()
|
292
|
+
_compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
|
@@ -80,7 +80,7 @@ class TestNXFlatField:
|
|
80
80
|
output_dir = self.params.get("output_dir", None)
|
81
81
|
if output_dir is not None:
|
82
82
|
output_dir = output_dir.format(tempdir=self.tempdir)
|
83
|
-
update_dataset_info_flats_darks(dataset_info, True, output_dir=output_dir)
|
83
|
+
update_dataset_info_flats_darks(dataset_info, True, loading_mode="load_if_present", output_dir=output_dir)
|
84
84
|
# After reduction (median/mean), the flats/darks are located in another file.
|
85
85
|
# median(series_1) goes to entry/flats/idx1, mean(series_2) goes to entry/flats/idx2, etc.
|
86
86
|
assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"]) # noqa: C401
|
nabu/resources/utils.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
from ast import literal_eval
|
2
2
|
import numpy as np
|
3
|
+
import pint
|
3
4
|
from psutil import virtual_memory, cpu_count
|
4
|
-
|
5
|
-
|
5
|
+
|
6
|
+
_ureg = pint.get_application_registry()
|
6
7
|
|
7
8
|
|
8
9
|
def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False):
|
@@ -163,12 +164,17 @@ def get_quantities_and_units(string, sep=";"):
|
|
163
164
|
value, unit = value_and_unit.split()
|
164
165
|
val = float(value)
|
165
166
|
# Convert to SI
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
167
|
+
if unit.lower() == "kev":
|
168
|
+
current_unit = _ureg.keV
|
169
|
+
elif unit.lower() == "ev":
|
170
|
+
current_unit = _ureg.eV
|
171
|
+
else:
|
172
|
+
current_unit = _ureg(unit)
|
173
|
+
# handle energies (to move to keV)
|
174
|
+
if _ureg.keV.dimensionality == current_unit.dimensionality:
|
175
|
+
result[quantity_name] = (val * current_unit).to(_ureg.keV).magnitude
|
176
|
+
elif _ureg.meter.dimensionality == current_unit.dimensionality:
|
177
|
+
result[quantity_name] = (val * current_unit).to_base_units().magnitude
|
178
|
+
else:
|
179
|
+
raise ValueError(f"Cannot convert: {unit}")
|
174
180
|
return result
|
nabu/stitching/alignment.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
|
+
from enum import Enum
|
1
2
|
import h5py
|
2
3
|
import numpy
|
3
4
|
from typing import Union
|
4
|
-
from silx.utils.enum import Enum as _Enum
|
5
5
|
|
6
6
|
|
7
|
-
class AlignmentAxis2(
|
7
|
+
class AlignmentAxis2(Enum):
|
8
8
|
"""Specific alignment named to help users orienting themself with specific name"""
|
9
9
|
|
10
10
|
CENTER = "center"
|
@@ -12,7 +12,7 @@ class AlignmentAxis2(_Enum):
|
|
12
12
|
RIGTH = "right"
|
13
13
|
|
14
14
|
|
15
|
-
class AlignmentAxis1(
|
15
|
+
class AlignmentAxis1(Enum):
|
16
16
|
"""Specific alignment named to help users orienting themself with specific name"""
|
17
17
|
|
18
18
|
FRONT = "front"
|
@@ -20,7 +20,7 @@ class AlignmentAxis1(_Enum):
|
|
20
20
|
BACK = "back"
|
21
21
|
|
22
22
|
|
23
|
-
class _Alignment(
|
23
|
+
class _Alignment(Enum):
|
24
24
|
"""Internal alignment to be used for 2D alignment"""
|
25
25
|
|
26
26
|
LOWER_BOUNDARY = "lower boundary"
|
@@ -29,7 +29,7 @@ class _Alignment(_Enum):
|
|
29
29
|
|
30
30
|
@classmethod
|
31
31
|
def from_value(cls, value):
|
32
|
-
# cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
|
32
|
+
# cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition.
|
33
33
|
if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
|
34
34
|
return _Alignment.LOWER_BOUNDARY
|
35
35
|
elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
|
@@ -37,7 +37,7 @@ class _Alignment(_Enum):
|
|
37
37
|
elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
|
38
38
|
return _Alignment.CENTER
|
39
39
|
else:
|
40
|
-
return super().
|
40
|
+
return super().__new__(cls, value)
|
41
41
|
|
42
42
|
|
43
43
|
def align_frame(
|
@@ -103,7 +103,7 @@ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width
|
|
103
103
|
:param HAlignment alignment: alignment strategy
|
104
104
|
:param int new_width: output data width
|
105
105
|
"""
|
106
|
-
alignment = AlignmentAxis2
|
106
|
+
alignment = AlignmentAxis2(alignment).value
|
107
107
|
return align_frame(
|
108
108
|
data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
|
109
109
|
)
|
nabu/stitching/config.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
+
import pint
|
1
2
|
from math import ceil
|
2
3
|
from typing import Optional, Union
|
3
4
|
from collections.abc import Iterable, Sized
|
4
5
|
from dataclasses import dataclass
|
5
6
|
import numpy
|
6
|
-
from pyunitsystem.metricsystem import MetricSystem
|
7
7
|
from nxtomo.paths import nxtomo
|
8
8
|
from tomoscan.factory import Factory
|
9
9
|
from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
|
@@ -18,6 +18,8 @@ from .utils.utils import ShiftAlgorithm
|
|
18
18
|
from .definitions import StitchingType
|
19
19
|
from .alignment import AlignmentAxis1, AlignmentAxis2
|
20
20
|
|
21
|
+
_ureg = pint.get_application_registry()
|
22
|
+
|
21
23
|
# ruff: noqa: S105
|
22
24
|
|
23
25
|
KEY_IMG_REG_METHOD = "img_reg_method"
|
@@ -220,7 +222,7 @@ def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
|
|
220
222
|
if my_str == "":
|
221
223
|
return None
|
222
224
|
try:
|
223
|
-
shift = ShiftAlgorithm
|
225
|
+
shift = ShiftAlgorithm(my_str)
|
224
226
|
except ValueError:
|
225
227
|
shifts_as_str = filter(None, my_str.replace(";", ",").split(","))
|
226
228
|
return [float(shift) for shift in shifts_as_str]
|
@@ -336,7 +338,7 @@ class NormalizationBySample:
|
|
336
338
|
|
337
339
|
@method.setter
|
338
340
|
def method(self, method: Union[Method, str]) -> None:
|
339
|
-
self._method = Method
|
341
|
+
self._method = Method(method)
|
340
342
|
|
341
343
|
@property
|
342
344
|
def margin(self) -> int:
|
@@ -353,7 +355,7 @@ class NormalizationBySample:
|
|
353
355
|
|
354
356
|
@side.setter
|
355
357
|
def side(self, side: Union[SampleSide, str]):
|
356
|
-
self._side = SampleSide
|
358
|
+
self._side = SampleSide(side)
|
357
359
|
|
358
360
|
@property
|
359
361
|
def width(self) -> int:
|
@@ -548,12 +550,12 @@ class StitchingConfiguration:
|
|
548
550
|
STITCHING_SECTION: {
|
549
551
|
STITCHING_TYPE_FIELD: {
|
550
552
|
"default": StitchingType.Z_PREPROC.value,
|
551
|
-
"help": f"stitching to be applied. Must be in {StitchingType
|
553
|
+
"help": f"stitching to be applied. Must be in {[st.value for st in StitchingType]}",
|
552
554
|
"type": "required",
|
553
555
|
},
|
554
556
|
STITCHING_STRATEGY_FIELD: {
|
555
557
|
"default": "cosinus weights",
|
556
|
-
"help": f"Policy to apply to compute the overlap area. Must be in {OverlapStitchingStrategy
|
558
|
+
"help": f"Policy to apply to compute the overlap area. Must be in {[ov.value for ov in OverlapStitchingStrategy]}.",
|
557
559
|
"type": "required",
|
558
560
|
},
|
559
561
|
CROSS_CORRELATION_SLICE_FIELD: {
|
@@ -633,7 +635,7 @@ class StitchingConfiguration:
|
|
633
635
|
},
|
634
636
|
ALIGNMENT_AXIS_2_FIELD: {
|
635
637
|
"default": "center",
|
636
|
-
"help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2
|
638
|
+
"help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {[aa.value for aa in AlignmentAxis2]}",
|
637
639
|
"type": "advanced",
|
638
640
|
},
|
639
641
|
PAD_MODE_FIELD: {
|
@@ -755,7 +757,7 @@ class StitchingConfiguration:
|
|
755
757
|
AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px),
|
756
758
|
AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm),
|
757
759
|
AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}),
|
758
|
-
STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy
|
760
|
+
STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy(self.stitching_strategy).value,
|
759
761
|
FLIP_UD: self.flip_ud,
|
760
762
|
FLIP_LR: self.flip_lr,
|
761
763
|
RESCALE_FRAMES: self.rescale_frames,
|
@@ -934,7 +936,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
934
936
|
if self.pixel_size is None:
|
935
937
|
pixel_size_mm = ""
|
936
938
|
else:
|
937
|
-
pixel_size_mm = self.pixel_size *
|
939
|
+
pixel_size_mm = (self.pixel_size * _ureg.meter).to(_ureg.millimeter).magnitude
|
938
940
|
return concatenate_dict(
|
939
941
|
super().to_dict(),
|
940
942
|
{
|
@@ -998,10 +1000,10 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
998
1000
|
if pixel_size == "":
|
999
1001
|
pixel_size = None
|
1000
1002
|
else:
|
1001
|
-
pixel_size = float(pixel_size)
|
1003
|
+
pixel_size = (float(pixel_size) * _ureg.millimeter).to_base_units().magnitude
|
1002
1004
|
|
1003
1005
|
return cls(
|
1004
|
-
stitching_strategy=OverlapStitchingStrategy
|
1006
|
+
stitching_strategy=OverlapStitchingStrategy(
|
1005
1007
|
config[STITCHING_SECTION].get(
|
1006
1008
|
STITCHING_STRATEGY_FIELD,
|
1007
1009
|
OverlapStitchingStrategy.COSINUS_WEIGHTS,
|
@@ -1042,7 +1044,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
1042
1044
|
config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
|
1043
1045
|
)
|
1044
1046
|
),
|
1045
|
-
alignment_axis_2=AlignmentAxis2
|
1047
|
+
alignment_axis_2=AlignmentAxis2(
|
1046
1048
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
|
1047
1049
|
),
|
1048
1050
|
pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
|
@@ -1163,11 +1165,11 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
1163
1165
|
if voxel_size == "":
|
1164
1166
|
voxel_size = None
|
1165
1167
|
else:
|
1166
|
-
voxel_size = float(voxel_size) *
|
1168
|
+
voxel_size = (float(voxel_size) * _ureg.millimeter).to_base_units().magnitude
|
1167
1169
|
|
1168
1170
|
# on the next section the one with a default value qre the optional one
|
1169
1171
|
return cls(
|
1170
|
-
stitching_strategy=OverlapStitchingStrategy
|
1172
|
+
stitching_strategy=OverlapStitchingStrategy(
|
1171
1173
|
config[STITCHING_SECTION].get(
|
1172
1174
|
STITCHING_STRATEGY_FIELD,
|
1173
1175
|
OverlapStitchingStrategy.COSINUS_WEIGHTS,
|
@@ -1198,10 +1200,10 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
1198
1200
|
config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
|
1199
1201
|
)
|
1200
1202
|
),
|
1201
|
-
alignment_axis_1=AlignmentAxis1
|
1203
|
+
alignment_axis_1=AlignmentAxis1(
|
1202
1204
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
|
1203
1205
|
),
|
1204
|
-
alignment_axis_2=AlignmentAxis2
|
1206
|
+
alignment_axis_2=AlignmentAxis2(
|
1205
1207
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
|
1206
1208
|
),
|
1207
1209
|
pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
|
@@ -1215,7 +1217,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
1215
1217
|
if self.voxel_size is None:
|
1216
1218
|
voxel_size_mm = ""
|
1217
1219
|
else:
|
1218
|
-
voxel_size_mm = numpy.array(self.voxel_size
|
1220
|
+
voxel_size_mm = numpy.array((self.voxel_size * _ureg.meter).to(_ureg.millimeter).magnitude)
|
1219
1221
|
|
1220
1222
|
return concatenate_dict(
|
1221
1223
|
super().to_dict(),
|
@@ -1250,7 +1252,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
1250
1252
|
STITCHING_SECTION: {
|
1251
1253
|
ALIGNMENT_AXIS_1_FIELD: {
|
1252
1254
|
"default": "center",
|
1253
|
-
"help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1
|
1255
|
+
"help": f"alignment to apply over axis 1 if needed. Valid values are {[aa for aa in AlignmentAxis1]}",
|
1254
1256
|
"type": "advanced",
|
1255
1257
|
}
|
1256
1258
|
},
|
@@ -1281,7 +1283,7 @@ def dict_to_config_obj(config: dict):
|
|
1281
1283
|
if stitching_type is None:
|
1282
1284
|
raise ValueError("Unable to find stitching type from config dict")
|
1283
1285
|
else:
|
1284
|
-
stitching_type = StitchingType
|
1286
|
+
stitching_type = StitchingType(stitching_type)
|
1285
1287
|
if stitching_type is StitchingType.Z_POSTPROC:
|
1286
1288
|
return PostProcessedZStitchingConfiguration.from_dict(config)
|
1287
1289
|
elif stitching_type is StitchingType.Z_PREPROC:
|
@@ -1304,7 +1306,7 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s
|
|
1304
1306
|
if stitching_type is None:
|
1305
1307
|
return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config)
|
1306
1308
|
|
1307
|
-
stitching_type = StitchingType
|
1309
|
+
stitching_type = StitchingType(stitching_type)
|
1308
1310
|
if stitching_type is StitchingType.Z_POSTPROC:
|
1309
1311
|
return z_postproc_stitching_config
|
1310
1312
|
elif stitching_type is StitchingType.Z_PREPROC:
|
nabu/stitching/definitions.py
CHANGED
nabu/stitching/overlap.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import numpy
|
2
2
|
import logging
|
3
3
|
from typing import Optional, Union
|
4
|
-
from
|
4
|
+
from enum import Enum
|
5
5
|
from nabu.misc import fourier_filters
|
6
6
|
from scipy.fft import rfftn as local_fftn
|
7
7
|
from scipy.fft import irfftn as local_ifftn
|
@@ -10,7 +10,7 @@ from tomoscan.utils.geometry import BoundingBox1D
|
|
10
10
|
_logger = logging.getLogger(__name__)
|
11
11
|
|
12
12
|
|
13
|
-
class OverlapStitchingStrategy(
|
13
|
+
class OverlapStitchingStrategy(Enum):
|
14
14
|
MEAN = "mean"
|
15
15
|
COSINUS_WEIGHTS = "cosinus weights"
|
16
16
|
LINEAR_WEIGHTS = "linear weights"
|
@@ -72,7 +72,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
|
|
72
72
|
self._stitching_axis = stitching_axis
|
73
73
|
self._overlap_size = abs(overlap_size)
|
74
74
|
self._frame_unstitched_axis_size = frame_unstitched_axis_size
|
75
|
-
self._stitching_strategy = OverlapStitchingStrategy
|
75
|
+
self._stitching_strategy = OverlapStitchingStrategy(stitching_strategy)
|
76
76
|
self._weights_img_1 = None
|
77
77
|
self._weights_img_2 = None
|
78
78
|
if extra_params is None:
|
@@ -391,7 +391,7 @@ def check_overlaps(frames: Union[tuple, numpy.ndarray], positions: tuple, axis:
|
|
391
391
|
|
392
392
|
:return: (tested_bounding_box, bounding_boxes_to_test)
|
393
393
|
"""
|
394
|
-
my_bounding_boxes = {bb_index: bb for bb_index, bb in enumerate(my_bounding_boxes)}
|
394
|
+
my_bounding_boxes = {bb_index: bb for bb_index, bb in enumerate(my_bounding_boxes)}
|
395
395
|
bounding_boxes = dict(
|
396
396
|
filter(
|
397
397
|
lambda pair: pair[0] not in (index - 1, index, index + 1),
|
@@ -1,13 +1,13 @@
|
|
1
|
+
from enum import Enum
|
1
2
|
import numpy
|
2
|
-
from silx.utils.enum import Enum as _Enum
|
3
3
|
|
4
4
|
|
5
|
-
class SampleSide(
|
5
|
+
class SampleSide(Enum):
|
6
6
|
LEFT = "left"
|
7
7
|
RIGHT = "right"
|
8
8
|
|
9
9
|
|
10
|
-
class Method(
|
10
|
+
class Method(Enum):
|
11
11
|
MEAN = "mean"
|
12
12
|
MEDIAN = "median"
|
13
13
|
|
@@ -28,8 +28,8 @@ def normalize_frame(
|
|
28
28
|
raise TypeError(f"Frame is expected to be a 2D numpy array.")
|
29
29
|
if frame.ndim != 2:
|
30
30
|
raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
|
31
|
-
side = SampleSide
|
32
|
-
method = Method
|
31
|
+
side = SampleSide(side)
|
32
|
+
method = Method(method)
|
33
33
|
|
34
34
|
if frame.shape[1] < sample_width + margin_before_sample:
|
35
35
|
raise ValueError(
|
@@ -2,6 +2,7 @@ import logging
|
|
2
2
|
import numpy
|
3
3
|
import os
|
4
4
|
import h5py
|
5
|
+
import pint
|
5
6
|
from typing import Union
|
6
7
|
from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
|
7
8
|
from nabu.stitching.alignment import AlignmentAxis1
|
@@ -15,7 +16,6 @@ from tomoscan.volumebase import VolumeBase
|
|
15
16
|
from tomoscan.esrf.volume import HDF5Volume
|
16
17
|
from collections.abc import Iterable
|
17
18
|
from contextlib import AbstractContextManager
|
18
|
-
from pyunitsystem.metricsystem import MetricSystem
|
19
19
|
from nabu.stitching.config import (
|
20
20
|
KEY_IMG_REG_METHOD,
|
21
21
|
)
|
@@ -25,6 +25,8 @@ from .single_axis import SingleAxisStitcher
|
|
25
25
|
|
26
26
|
_logger = logging.getLogger(__name__)
|
27
27
|
|
28
|
+
_ureg = pint.get_application_registry()
|
29
|
+
|
28
30
|
|
29
31
|
class FlippingValueError(ValueError):
|
30
32
|
pass
|
@@ -266,7 +268,7 @@ class PostProcessingStitching(SingleAxisStitcher):
|
|
266
268
|
axis_N_pos_px = []
|
267
269
|
for volume, pos_in_mm in zip(self.series, pos_as_mm):
|
268
270
|
voxel_size_m = self.configuration.voxel_size or volume.voxel_size
|
269
|
-
axis_N_pos_px.append((pos_in_mm
|
271
|
+
axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / voxel_size_m[0])
|
270
272
|
return axis_N_pos_px
|
271
273
|
else:
|
272
274
|
# deduce from motor position and pixel size
|
@@ -548,7 +550,7 @@ class _RawDatasetsContext(AbstractContextManager):
|
|
548
550
|
return success
|
549
551
|
|
550
552
|
def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
|
551
|
-
alignment = AlignmentAxis1
|
553
|
+
alignment = AlignmentAxis1(alignment)
|
552
554
|
if alignment is AlignmentAxis1.BACK:
|
553
555
|
axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
|
554
556
|
elif alignment is AlignmentAxis1.CENTER:
|