nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +36 -5
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.14.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/resources/utils.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from ast import literal_eval
|
|
2
2
|
import numpy as np
|
|
3
|
+
import pint
|
|
3
4
|
from psutil import virtual_memory, cpu_count
|
|
4
|
-
|
|
5
|
-
|
|
5
|
+
|
|
6
|
+
_ureg = pint.get_application_registry()
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False):
|
|
@@ -163,12 +164,17 @@ def get_quantities_and_units(string, sep=";"):
|
|
|
163
164
|
value, unit = value_and_unit.split()
|
|
164
165
|
val = float(value)
|
|
165
166
|
# Convert to SI
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
167
|
+
if unit.lower() == "kev":
|
|
168
|
+
current_unit = _ureg.keV
|
|
169
|
+
elif unit.lower() == "ev":
|
|
170
|
+
current_unit = _ureg.eV
|
|
171
|
+
else:
|
|
172
|
+
current_unit = _ureg(unit)
|
|
173
|
+
# handle energies (to move to keV)
|
|
174
|
+
if _ureg.keV.dimensionality == current_unit.dimensionality:
|
|
175
|
+
result[quantity_name] = (val * current_unit).to(_ureg.keV).magnitude
|
|
176
|
+
elif _ureg.meter.dimensionality == current_unit.dimensionality:
|
|
177
|
+
result[quantity_name] = (val * current_unit).to_base_units().magnitude
|
|
178
|
+
else:
|
|
179
|
+
raise ValueError(f"Cannot convert: {unit}")
|
|
174
180
|
return result
|
nabu/stitching/alignment.py
CHANGED
|
@@ -1,13 +1,10 @@
|
|
|
1
|
+
from enum import Enum
|
|
1
2
|
import h5py
|
|
2
3
|
import numpy
|
|
3
4
|
from typing import Union
|
|
4
|
-
from silx.utils.enum import Enum as _Enum
|
|
5
|
-
from tomoscan.volumebase import VolumeBase
|
|
6
|
-
from tomoscan.esrf.volume.hdf5volume import HDF5Volume
|
|
7
|
-
from nabu.io.utils import DatasetReader
|
|
8
5
|
|
|
9
6
|
|
|
10
|
-
class AlignmentAxis2(
|
|
7
|
+
class AlignmentAxis2(Enum):
|
|
11
8
|
"""Specific alignment named to help users orienting themself with specific name"""
|
|
12
9
|
|
|
13
10
|
CENTER = "center"
|
|
@@ -15,7 +12,7 @@ class AlignmentAxis2(_Enum):
|
|
|
15
12
|
RIGTH = "right"
|
|
16
13
|
|
|
17
14
|
|
|
18
|
-
class AlignmentAxis1(
|
|
15
|
+
class AlignmentAxis1(Enum):
|
|
19
16
|
"""Specific alignment named to help users orienting themself with specific name"""
|
|
20
17
|
|
|
21
18
|
FRONT = "front"
|
|
@@ -23,7 +20,7 @@ class AlignmentAxis1(_Enum):
|
|
|
23
20
|
BACK = "back"
|
|
24
21
|
|
|
25
22
|
|
|
26
|
-
class _Alignment(
|
|
23
|
+
class _Alignment(Enum):
|
|
27
24
|
"""Internal alignment to be used for 2D alignment"""
|
|
28
25
|
|
|
29
26
|
LOWER_BOUNDARY = "lower boundary"
|
|
@@ -32,7 +29,7 @@ class _Alignment(_Enum):
|
|
|
32
29
|
|
|
33
30
|
@classmethod
|
|
34
31
|
def from_value(cls, value):
|
|
35
|
-
# cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
|
|
32
|
+
# cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition.
|
|
36
33
|
if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
|
|
37
34
|
return _Alignment.LOWER_BOUNDARY
|
|
38
35
|
elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
|
|
@@ -40,7 +37,7 @@ class _Alignment(_Enum):
|
|
|
40
37
|
elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
|
|
41
38
|
return _Alignment.CENTER
|
|
42
39
|
else:
|
|
43
|
-
return super().
|
|
40
|
+
return super().__new__(cls, value)
|
|
44
41
|
|
|
45
42
|
|
|
46
43
|
def align_frame(
|
|
@@ -106,7 +103,7 @@ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width
|
|
|
106
103
|
:param HAlignment alignment: alignment strategy
|
|
107
104
|
:param int new_width: output data width
|
|
108
105
|
"""
|
|
109
|
-
alignment = AlignmentAxis2
|
|
106
|
+
alignment = AlignmentAxis2(alignment).value
|
|
110
107
|
return align_frame(
|
|
111
108
|
data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
|
|
112
109
|
)
|
|
@@ -151,7 +148,7 @@ class PaddedRawData:
|
|
|
151
148
|
@property
|
|
152
149
|
def shape(self):
|
|
153
150
|
if self._shape is None:
|
|
154
|
-
self._shape = tuple(
|
|
151
|
+
self._shape = tuple( # noqa: C409
|
|
155
152
|
(
|
|
156
153
|
self._raw_data_shape[0],
|
|
157
154
|
numpy.sum(
|
nabu/stitching/config.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import pint
|
|
1
2
|
from math import ceil
|
|
2
|
-
from typing import
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
from collections.abc import Iterable, Sized
|
|
3
5
|
from dataclasses import dataclass
|
|
4
6
|
import numpy
|
|
5
|
-
from pyunitsystem.metricsystem import MetricSystem
|
|
6
7
|
from nxtomo.paths import nxtomo
|
|
7
8
|
from tomoscan.factory import Factory
|
|
8
9
|
from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
|
|
@@ -12,13 +13,14 @@ from ..pipeline.config_validators import (
|
|
|
12
13
|
convert_to_bool,
|
|
13
14
|
)
|
|
14
15
|
from ..utils import concatenate_dict, convert_str_to_tuple
|
|
15
|
-
from ..io.utils import get_output_volume
|
|
16
16
|
from .overlap import OverlapStitchingStrategy
|
|
17
17
|
from .utils.utils import ShiftAlgorithm
|
|
18
18
|
from .definitions import StitchingType
|
|
19
19
|
from .alignment import AlignmentAxis1, AlignmentAxis2
|
|
20
|
-
from pyunitsystem.metricsystem import MetricSystem
|
|
21
20
|
|
|
21
|
+
_ureg = pint.get_application_registry()
|
|
22
|
+
|
|
23
|
+
# ruff: noqa: S105
|
|
22
24
|
|
|
23
25
|
KEY_IMG_REG_METHOD = "img_reg_method"
|
|
24
26
|
|
|
@@ -128,6 +130,8 @@ SLURM_MODULES_TO_LOADS = "modules"
|
|
|
128
130
|
|
|
129
131
|
SLURM_CLEAN_SCRIPTS = "clean_scripts"
|
|
130
132
|
|
|
133
|
+
SLURM_JOB_NAME = "job_name"
|
|
134
|
+
|
|
131
135
|
# normalization by sample
|
|
132
136
|
|
|
133
137
|
NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample"
|
|
@@ -205,7 +209,7 @@ def _str_to_dict(my_str: Union[str, dict]):
|
|
|
205
209
|
|
|
206
210
|
|
|
207
211
|
def _dict_to_str(ddict: dict):
|
|
208
|
-
return ";".join([f"{
|
|
212
|
+
return ";".join([f"{key!s}={value!s}" for key, value in ddict.items()])
|
|
209
213
|
|
|
210
214
|
|
|
211
215
|
def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
|
|
@@ -218,7 +222,7 @@ def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
|
|
|
218
222
|
if my_str == "":
|
|
219
223
|
return None
|
|
220
224
|
try:
|
|
221
|
-
shift = ShiftAlgorithm
|
|
225
|
+
shift = ShiftAlgorithm(my_str)
|
|
222
226
|
except ValueError:
|
|
223
227
|
shifts_as_str = filter(None, my_str.replace(";", ",").split(","))
|
|
224
228
|
return [float(shift) for shift in shifts_as_str]
|
|
@@ -235,8 +239,8 @@ def _valid_stitching_kernels_params(my_dict: Union[dict, str]):
|
|
|
235
239
|
my_dict = _str_to_dict(my_str=my_dict)
|
|
236
240
|
|
|
237
241
|
valid_keys = (KEY_THRESHOLD_FREQUENCY, KEY_SIDE)
|
|
238
|
-
for key in my_dict
|
|
239
|
-
if not
|
|
242
|
+
for key in my_dict:
|
|
243
|
+
if key not in valid_keys:
|
|
240
244
|
raise KeyError(f"{key} is a unrecognized key")
|
|
241
245
|
return my_dict
|
|
242
246
|
|
|
@@ -253,8 +257,8 @@ def _valid_shifts_params(my_dict: Union[dict, str]):
|
|
|
253
257
|
KEY_LOW_PASS_FILTER,
|
|
254
258
|
KEY_SIDE,
|
|
255
259
|
)
|
|
256
|
-
for key in my_dict
|
|
257
|
-
if not
|
|
260
|
+
for key in my_dict:
|
|
261
|
+
if key not in valid_keys:
|
|
258
262
|
raise KeyError(f"{key} is a unrecognized key")
|
|
259
263
|
return my_dict
|
|
260
264
|
|
|
@@ -334,7 +338,7 @@ class NormalizationBySample:
|
|
|
334
338
|
|
|
335
339
|
@method.setter
|
|
336
340
|
def method(self, method: Union[Method, str]) -> None:
|
|
337
|
-
self._method = Method
|
|
341
|
+
self._method = Method(method)
|
|
338
342
|
|
|
339
343
|
@property
|
|
340
344
|
def margin(self) -> int:
|
|
@@ -351,7 +355,7 @@ class NormalizationBySample:
|
|
|
351
355
|
|
|
352
356
|
@side.setter
|
|
353
357
|
def side(self, side: Union[SampleSide, str]):
|
|
354
|
-
self._side = SampleSide
|
|
358
|
+
self._side = SampleSide(side)
|
|
355
359
|
|
|
356
360
|
@property
|
|
357
361
|
def width(self) -> int:
|
|
@@ -401,16 +405,16 @@ class NormalizationBySample:
|
|
|
401
405
|
NORMALIZATION_BY_SAMPLE_WIDTH: self.width,
|
|
402
406
|
}
|
|
403
407
|
|
|
404
|
-
def __eq__(self,
|
|
405
|
-
if not isinstance(
|
|
408
|
+
def __eq__(self, value: object, /) -> bool:
|
|
409
|
+
if not isinstance(value, NormalizationBySample):
|
|
406
410
|
return False
|
|
407
411
|
else:
|
|
408
|
-
return self.to_dict() ==
|
|
412
|
+
return self.to_dict() == value.to_dict()
|
|
409
413
|
|
|
410
414
|
|
|
411
415
|
@dataclass
|
|
412
416
|
class SlurmConfig:
|
|
413
|
-
"configuration for slurm jobs"
|
|
417
|
+
"""configuration for slurm jobs"""
|
|
414
418
|
|
|
415
419
|
partition: str = "" # note: must stay empty to make by default we don't use slurm (use by the configuration file)
|
|
416
420
|
mem: str = "128"
|
|
@@ -421,6 +425,7 @@ class SlurmConfig:
|
|
|
421
425
|
clean_script: bool = ""
|
|
422
426
|
n_tasks: int = 1
|
|
423
427
|
n_cpu_per_task: int = 4
|
|
428
|
+
job_name: str = ""
|
|
424
429
|
|
|
425
430
|
def __post_init__(self) -> None:
|
|
426
431
|
# make sure either 'modules' or 'preprocessing_command' is provided
|
|
@@ -430,7 +435,7 @@ class SlurmConfig:
|
|
|
430
435
|
)
|
|
431
436
|
|
|
432
437
|
def to_dict(self) -> dict:
|
|
433
|
-
"dump configuration to dict"
|
|
438
|
+
"""dump configuration to dict"""
|
|
434
439
|
return {
|
|
435
440
|
SLURM_PARTITION: self.partition if self.partition is not None else "",
|
|
436
441
|
SLURM_MEM: self.mem,
|
|
@@ -441,6 +446,7 @@ class SlurmConfig:
|
|
|
441
446
|
SLURM_CLEAN_SCRIPTS: self.clean_script,
|
|
442
447
|
SLURM_NUMBER_OF_TASKS: self.n_tasks,
|
|
443
448
|
SLURM_COR_PER_TASKS: self.n_cpu_per_task,
|
|
449
|
+
SLURM_JOB_NAME: self.job_name,
|
|
444
450
|
}
|
|
445
451
|
|
|
446
452
|
@staticmethod
|
|
@@ -457,18 +463,21 @@ class SlurmConfig:
|
|
|
457
463
|
preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""),
|
|
458
464
|
modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")),
|
|
459
465
|
clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0],
|
|
466
|
+
job_name=config.get(SLURM_JOB_NAME, ""),
|
|
460
467
|
)
|
|
461
468
|
|
|
462
469
|
|
|
463
|
-
def _cast_shift_to_str(shifts: Union[tuple, str, None]) -> str:
|
|
470
|
+
def _cast_shift_to_str(shifts: Union[tuple, numpy.ndarray, str, None]) -> str:
|
|
464
471
|
if shifts is None:
|
|
465
472
|
return ""
|
|
466
473
|
elif isinstance(shifts, ShiftAlgorithm):
|
|
467
474
|
return shifts.value
|
|
468
475
|
elif isinstance(shifts, str):
|
|
469
476
|
return shifts
|
|
470
|
-
elif isinstance(shifts, (tuple, list)):
|
|
477
|
+
elif isinstance(shifts, (tuple, list, numpy.ndarray)):
|
|
471
478
|
return ";".join([str(value) for value in shifts])
|
|
479
|
+
else:
|
|
480
|
+
raise TypeError(f"unexpected type: {type(shifts)}")
|
|
472
481
|
|
|
473
482
|
|
|
474
483
|
@dataclass
|
|
@@ -541,12 +550,12 @@ class StitchingConfiguration:
|
|
|
541
550
|
STITCHING_SECTION: {
|
|
542
551
|
STITCHING_TYPE_FIELD: {
|
|
543
552
|
"default": StitchingType.Z_PREPROC.value,
|
|
544
|
-
"help": f"stitching to be applied. Must be in {StitchingType
|
|
553
|
+
"help": f"stitching to be applied. Must be in {[st.value for st in StitchingType]}",
|
|
545
554
|
"type": "required",
|
|
546
555
|
},
|
|
547
556
|
STITCHING_STRATEGY_FIELD: {
|
|
548
557
|
"default": "cosinus weights",
|
|
549
|
-
"help": f"Policy to apply to compute the overlap area. Must be in {OverlapStitchingStrategy
|
|
558
|
+
"help": f"Policy to apply to compute the overlap area. Must be in {[ov.value for ov in OverlapStitchingStrategy]}.",
|
|
550
559
|
"type": "required",
|
|
551
560
|
},
|
|
552
561
|
CROSS_CORRELATION_SLICE_FIELD: {
|
|
@@ -626,7 +635,7 @@ class StitchingConfiguration:
|
|
|
626
635
|
},
|
|
627
636
|
ALIGNMENT_AXIS_2_FIELD: {
|
|
628
637
|
"default": "center",
|
|
629
|
-
"help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2
|
|
638
|
+
"help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {[aa.value for aa in AlignmentAxis2]}",
|
|
630
639
|
"type": "advanced",
|
|
631
640
|
},
|
|
632
641
|
PAD_MODE_FIELD: {
|
|
@@ -748,7 +757,7 @@ class StitchingConfiguration:
|
|
|
748
757
|
AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px),
|
|
749
758
|
AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm),
|
|
750
759
|
AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}),
|
|
751
|
-
STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy
|
|
760
|
+
STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy(self.stitching_strategy).value,
|
|
752
761
|
FLIP_UD: self.flip_ud,
|
|
753
762
|
FLIP_LR: self.flip_lr,
|
|
754
763
|
RESCALE_FRAMES: self.rescale_frames,
|
|
@@ -927,7 +936,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
|
927
936
|
if self.pixel_size is None:
|
|
928
937
|
pixel_size_mm = ""
|
|
929
938
|
else:
|
|
930
|
-
pixel_size_mm = self.pixel_size *
|
|
939
|
+
pixel_size_mm = (self.pixel_size * _ureg.meter).to(_ureg.millimeter).magnitude
|
|
931
940
|
return concatenate_dict(
|
|
932
941
|
super().to_dict(),
|
|
933
942
|
{
|
|
@@ -991,10 +1000,10 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
|
991
1000
|
if pixel_size == "":
|
|
992
1001
|
pixel_size = None
|
|
993
1002
|
else:
|
|
994
|
-
pixel_size = float(pixel_size)
|
|
1003
|
+
pixel_size = (float(pixel_size) * _ureg.millimeter).to_base_units().magnitude
|
|
995
1004
|
|
|
996
1005
|
return cls(
|
|
997
|
-
stitching_strategy=OverlapStitchingStrategy
|
|
1006
|
+
stitching_strategy=OverlapStitchingStrategy(
|
|
998
1007
|
config[STITCHING_SECTION].get(
|
|
999
1008
|
STITCHING_STRATEGY_FIELD,
|
|
1000
1009
|
OverlapStitchingStrategy.COSINUS_WEIGHTS,
|
|
@@ -1035,7 +1044,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
|
|
|
1035
1044
|
config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
|
|
1036
1045
|
)
|
|
1037
1046
|
),
|
|
1038
|
-
alignment_axis_2=AlignmentAxis2
|
|
1047
|
+
alignment_axis_2=AlignmentAxis2(
|
|
1039
1048
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
|
|
1040
1049
|
),
|
|
1041
1050
|
pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
|
|
@@ -1156,11 +1165,11 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
|
1156
1165
|
if voxel_size == "":
|
|
1157
1166
|
voxel_size = None
|
|
1158
1167
|
else:
|
|
1159
|
-
voxel_size = float(voxel_size) *
|
|
1168
|
+
voxel_size = (float(voxel_size) * _ureg.millimeter).to_base_units().magnitude
|
|
1160
1169
|
|
|
1161
1170
|
# on the next section the one with a default value qre the optional one
|
|
1162
1171
|
return cls(
|
|
1163
|
-
stitching_strategy=OverlapStitchingStrategy
|
|
1172
|
+
stitching_strategy=OverlapStitchingStrategy(
|
|
1164
1173
|
config[STITCHING_SECTION].get(
|
|
1165
1174
|
STITCHING_STRATEGY_FIELD,
|
|
1166
1175
|
OverlapStitchingStrategy.COSINUS_WEIGHTS,
|
|
@@ -1191,10 +1200,10 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
|
1191
1200
|
config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
|
|
1192
1201
|
)
|
|
1193
1202
|
),
|
|
1194
|
-
alignment_axis_1=AlignmentAxis1
|
|
1203
|
+
alignment_axis_1=AlignmentAxis1(
|
|
1195
1204
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
|
|
1196
1205
|
),
|
|
1197
|
-
alignment_axis_2=AlignmentAxis2
|
|
1206
|
+
alignment_axis_2=AlignmentAxis2(
|
|
1198
1207
|
config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
|
|
1199
1208
|
),
|
|
1200
1209
|
pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
|
|
@@ -1208,7 +1217,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
|
1208
1217
|
if self.voxel_size is None:
|
|
1209
1218
|
voxel_size_mm = ""
|
|
1210
1219
|
else:
|
|
1211
|
-
voxel_size_mm = numpy.array(self.voxel_size
|
|
1220
|
+
voxel_size_mm = numpy.array((self.voxel_size * _ureg.meter).to(_ureg.millimeter).magnitude)
|
|
1212
1221
|
|
|
1213
1222
|
return concatenate_dict(
|
|
1214
1223
|
super().to_dict(),
|
|
@@ -1243,7 +1252,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
|
|
|
1243
1252
|
STITCHING_SECTION: {
|
|
1244
1253
|
ALIGNMENT_AXIS_1_FIELD: {
|
|
1245
1254
|
"default": "center",
|
|
1246
|
-
"help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1
|
|
1255
|
+
"help": f"alignment to apply over axis 1 if needed. Valid values are {[aa for aa in AlignmentAxis1]}",
|
|
1247
1256
|
"type": "advanced",
|
|
1248
1257
|
}
|
|
1249
1258
|
},
|
|
@@ -1274,7 +1283,7 @@ def dict_to_config_obj(config: dict):
|
|
|
1274
1283
|
if stitching_type is None:
|
|
1275
1284
|
raise ValueError("Unable to find stitching type from config dict")
|
|
1276
1285
|
else:
|
|
1277
|
-
stitching_type = StitchingType
|
|
1286
|
+
stitching_type = StitchingType(stitching_type)
|
|
1278
1287
|
if stitching_type is StitchingType.Z_POSTPROC:
|
|
1279
1288
|
return PostProcessedZStitchingConfiguration.from_dict(config)
|
|
1280
1289
|
elif stitching_type is StitchingType.Z_PREPROC:
|
|
@@ -1297,7 +1306,7 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s
|
|
|
1297
1306
|
if stitching_type is None:
|
|
1298
1307
|
return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config)
|
|
1299
1308
|
|
|
1300
|
-
stitching_type = StitchingType
|
|
1309
|
+
stitching_type = StitchingType(stitching_type)
|
|
1301
1310
|
if stitching_type is StitchingType.Z_POSTPROC:
|
|
1302
1311
|
return z_postproc_stitching_config
|
|
1303
1312
|
elif stitching_type is StitchingType.Z_PREPROC:
|
nabu/stitching/definitions.py
CHANGED
|
@@ -31,7 +31,7 @@ class FrameComposition:
|
|
|
31
31
|
)
|
|
32
32
|
|
|
33
33
|
def compose(self, output_frame: numpy.ndarray, input_frames: tuple):
|
|
34
|
-
if
|
|
34
|
+
if output_frame.ndim not in (2, 3):
|
|
35
35
|
raise TypeError(
|
|
36
36
|
f"output_frame is expected to be 2D (gray scale) or 3D (RGB(A)) and not {output_frame.ndim}"
|
|
37
37
|
)
|
|
@@ -74,9 +74,10 @@ class FrameComposition:
|
|
|
74
74
|
local_start_indices.extend(
|
|
75
75
|
[ceil(key_line[1] + kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)]
|
|
76
76
|
)
|
|
77
|
-
local_end_indices =
|
|
78
|
-
|
|
79
|
-
|
|
77
|
+
local_end_indices = [
|
|
78
|
+
ceil(key_line[0] - kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)
|
|
79
|
+
]
|
|
80
|
+
|
|
80
81
|
local_end_indices.append(frames[-1].shape[stitching_axis])
|
|
81
82
|
|
|
82
83
|
for (
|
|
@@ -155,9 +156,6 @@ class FrameComposition:
|
|
|
155
156
|
print(
|
|
156
157
|
f"stitch_frame[{stitch_global_start}:{stitch_global_end}] = stitched_frame_{i_frame}[{stitch_local_start}:{stitch_local_end}]"
|
|
157
158
|
)
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
print(
|
|
162
|
-
f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]"
|
|
163
|
-
)
|
|
159
|
+
i_frame += 1
|
|
160
|
+
raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1]
|
|
161
|
+
print(f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]")
|
nabu/stitching/overlap.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import numpy
|
|
2
2
|
import logging
|
|
3
3
|
from typing import Optional, Union
|
|
4
|
-
from
|
|
4
|
+
from enum import Enum
|
|
5
5
|
from nabu.misc import fourier_filters
|
|
6
6
|
from scipy.fft import rfftn as local_fftn
|
|
7
7
|
from scipy.fft import irfftn as local_ifftn
|
|
@@ -10,7 +10,7 @@ from tomoscan.utils.geometry import BoundingBox1D
|
|
|
10
10
|
_logger = logging.getLogger(__name__)
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
class OverlapStitchingStrategy(
|
|
13
|
+
class OverlapStitchingStrategy(Enum):
|
|
14
14
|
MEAN = "mean"
|
|
15
15
|
COSINUS_WEIGHTS = "cosinus weights"
|
|
16
16
|
LINEAR_WEIGHTS = "linear weights"
|
|
@@ -64,7 +64,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
|
|
|
64
64
|
f"frame_width is expected to be a positive int, {frame_unstitched_axis_size} - not {frame_unstitched_axis_size} ({type(frame_unstitched_axis_size)})"
|
|
65
65
|
)
|
|
66
66
|
|
|
67
|
-
if not
|
|
67
|
+
if stitching_axis not in (0, 1):
|
|
68
68
|
raise ValueError(
|
|
69
69
|
"stitching_axis is expected to be the axis along which stitching must be done. It should be '0' or '1'"
|
|
70
70
|
)
|
|
@@ -72,7 +72,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
|
|
|
72
72
|
self._stitching_axis = stitching_axis
|
|
73
73
|
self._overlap_size = abs(overlap_size)
|
|
74
74
|
self._frame_unstitched_axis_size = frame_unstitched_axis_size
|
|
75
|
-
self._stitching_strategy = OverlapStitchingStrategy
|
|
75
|
+
self._stitching_strategy = OverlapStitchingStrategy(stitching_strategy)
|
|
76
76
|
self._weights_img_1 = None
|
|
77
77
|
self._weights_img_2 = None
|
|
78
78
|
if extra_params is None:
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
+
from enum import Enum
|
|
1
2
|
import numpy
|
|
2
|
-
from silx.utils.enum import Enum as _Enum
|
|
3
3
|
|
|
4
4
|
|
|
5
|
-
class SampleSide(
|
|
5
|
+
class SampleSide(Enum):
|
|
6
6
|
LEFT = "left"
|
|
7
7
|
RIGHT = "right"
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
class Method(
|
|
10
|
+
class Method(Enum):
|
|
11
11
|
MEAN = "mean"
|
|
12
12
|
MEDIAN = "median"
|
|
13
13
|
|
|
@@ -28,8 +28,8 @@ def normalize_frame(
|
|
|
28
28
|
raise TypeError(f"Frame is expected to be a 2D numpy array.")
|
|
29
29
|
if frame.ndim != 2:
|
|
30
30
|
raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
|
|
31
|
-
side = SampleSide
|
|
32
|
-
method = Method
|
|
31
|
+
side = SampleSide(side)
|
|
32
|
+
method = Method(method)
|
|
33
33
|
|
|
34
34
|
if frame.shape[1] < sample_width + margin_before_sample:
|
|
35
35
|
raise ValueError(
|
nabu/stitching/slurm_utils.py
CHANGED
|
@@ -177,7 +177,7 @@ def split_slices(slices: Union[slice, tuple], n_parts: int):
|
|
|
177
177
|
raise TypeError(f"slices type ({type(slices)}) is not handled. Must be a slice or an Iterable")
|
|
178
178
|
|
|
179
179
|
|
|
180
|
-
def get_working_directory(obj: TomoObject) -> Optional[str]:
|
|
180
|
+
def get_working_directory(obj: TomoObject) -> Optional[str]: # noqa: PLR0911
|
|
181
181
|
"""
|
|
182
182
|
return working directory for a specific TomoObject
|
|
183
183
|
"""
|
|
@@ -201,4 +201,4 @@ def get_working_directory(obj: TomoObject) -> Optional[str]:
|
|
|
201
201
|
else:
|
|
202
202
|
return os.path.abspath(os.path.dirname(obj.master_file))
|
|
203
203
|
else:
|
|
204
|
-
raise RuntimeError(f"obj type not handled ({type(obj)})")
|
|
204
|
+
raise RuntimeError(f"obj type not handled ({type(obj)})") # noqa: TRY004
|
nabu/stitching/stitcher/base.py
CHANGED
|
@@ -21,6 +21,8 @@ def get_obj_constant_side_length(obj: Union[NXtomoScan, VolumeBase], axis: int)
|
|
|
21
21
|
return obj.dim_1
|
|
22
22
|
elif axis in (1, 2):
|
|
23
23
|
return obj.dim_2
|
|
24
|
+
else:
|
|
25
|
+
raise ValueError(f"Axis ({axis}) not handled. Should be in (0, 1, 2)")
|
|
24
26
|
elif isinstance(obj, VolumeBase) and axis == 0:
|
|
25
27
|
return obj.get_volume_shape()[-1]
|
|
26
28
|
else:
|
|
@@ -96,7 +96,7 @@ class OutputVolumeContext(AbstractContextManager):
|
|
|
96
96
|
if self._file_handler is not None:
|
|
97
97
|
return self._file_handler.close()
|
|
98
98
|
else:
|
|
99
|
-
self._volume.save_data()
|
|
99
|
+
self._volume.save_data() # noqa: RET503
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
class OutputVolumeNoDDContext(OutputVolumeContext):
|
|
@@ -2,6 +2,7 @@ import logging
|
|
|
2
2
|
import numpy
|
|
3
3
|
import os
|
|
4
4
|
import h5py
|
|
5
|
+
import pint
|
|
5
6
|
from typing import Union
|
|
6
7
|
from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
|
|
7
8
|
from nabu.stitching.alignment import AlignmentAxis1
|
|
@@ -13,11 +14,9 @@ from tomoscan.esrf import NXtomoScan
|
|
|
13
14
|
from tomoscan.series import Series
|
|
14
15
|
from tomoscan.volumebase import VolumeBase
|
|
15
16
|
from tomoscan.esrf.volume import HDF5Volume
|
|
16
|
-
from
|
|
17
|
+
from collections.abc import Iterable
|
|
17
18
|
from contextlib import AbstractContextManager
|
|
18
|
-
from pyunitsystem.metricsystem import MetricSystem
|
|
19
19
|
from nabu.stitching.config import (
|
|
20
|
-
PostProcessedSingleAxisStitchingConfiguration,
|
|
21
20
|
KEY_IMG_REG_METHOD,
|
|
22
21
|
)
|
|
23
22
|
from nabu.stitching.utils.utils import find_volumes_relative_shifts
|
|
@@ -26,6 +25,8 @@ from .single_axis import SingleAxisStitcher
|
|
|
26
25
|
|
|
27
26
|
_logger = logging.getLogger(__name__)
|
|
28
27
|
|
|
28
|
+
_ureg = pint.get_application_registry()
|
|
29
|
+
|
|
29
30
|
|
|
30
31
|
class FlippingValueError(ValueError):
|
|
31
32
|
pass
|
|
@@ -267,7 +268,7 @@ class PostProcessingStitching(SingleAxisStitcher):
|
|
|
267
268
|
axis_N_pos_px = []
|
|
268
269
|
for volume, pos_in_mm in zip(self.series, pos_as_mm):
|
|
269
270
|
voxel_size_m = self.configuration.voxel_size or volume.voxel_size
|
|
270
|
-
axis_N_pos_px.append((pos_in_mm
|
|
271
|
+
axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / voxel_size_m[0])
|
|
271
272
|
return axis_N_pos_px
|
|
272
273
|
else:
|
|
273
274
|
# deduce from motor position and pixel size
|
|
@@ -426,7 +427,7 @@ class PostProcessingStitching(SingleAxisStitcher):
|
|
|
426
427
|
|
|
427
428
|
bunch_size = 50
|
|
428
429
|
# how many frame to we stitch between two read from disk / save to disk
|
|
429
|
-
with self.dumper.OutputDatasetContext(**output_dataset_args):
|
|
430
|
+
with self.dumper.OutputDatasetContext(**output_dataset_args): # noqa: SIM117
|
|
430
431
|
# note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
|
|
431
432
|
with _RawDatasetsContext(
|
|
432
433
|
self._input_volumes,
|
|
@@ -528,7 +529,8 @@ class _RawDatasetsContext(AbstractContextManager):
|
|
|
528
529
|
else:
|
|
529
530
|
data = volume.load_data(store=False)
|
|
530
531
|
if data is None:
|
|
531
|
-
|
|
532
|
+
# TODO
|
|
533
|
+
raise ValueError(f"No data found for volume {volume.get_identifier()}") # noqa: TRY301
|
|
532
534
|
if axis_1_need_padding:
|
|
533
535
|
data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
|
|
534
536
|
datasets.append(data)
|
|
@@ -536,7 +538,7 @@ class _RawDatasetsContext(AbstractContextManager):
|
|
|
536
538
|
# if some errors happen during loading HDF5
|
|
537
539
|
for file_handled in self.__file_handlers:
|
|
538
540
|
file_handled.close()
|
|
539
|
-
raise e
|
|
541
|
+
raise e # noqa: TRY201
|
|
540
542
|
|
|
541
543
|
return datasets
|
|
542
544
|
|
|
@@ -544,11 +546,11 @@ class _RawDatasetsContext(AbstractContextManager):
|
|
|
544
546
|
success = True
|
|
545
547
|
for file_handler in self.__file_handlers:
|
|
546
548
|
success = success and file_handler.close()
|
|
547
|
-
if exc_type is None:
|
|
549
|
+
if exc_type is None: # noqa: RET503
|
|
548
550
|
return success
|
|
549
551
|
|
|
550
552
|
def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
|
|
551
|
-
alignment = AlignmentAxis1
|
|
553
|
+
alignment = AlignmentAxis1(alignment)
|
|
552
554
|
if alignment is AlignmentAxis1.BACK:
|
|
553
555
|
axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
|
|
554
556
|
elif alignment is AlignmentAxis1.CENTER:
|