nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/conf.py +1 -1
- doc/doc_config.py +32 -0
- nabu/__init__.py +2 -1
- nabu/app/bootstrap_stitching.py +1 -1
- nabu/app/cli_configs.py +122 -2
- nabu/app/composite_cor.py +27 -2
- nabu/app/correct_rot.py +70 -0
- nabu/app/create_distortion_map_from_poly.py +42 -18
- nabu/app/diag_to_pix.py +358 -0
- nabu/app/diag_to_rot.py +449 -0
- nabu/app/generate_header.py +4 -3
- nabu/app/histogram.py +2 -2
- nabu/app/multicor.py +6 -1
- nabu/app/parse_reconstruction_log.py +151 -0
- nabu/app/prepare_weights_double.py +83 -22
- nabu/app/reconstruct.py +5 -1
- nabu/app/reconstruct_helical.py +7 -0
- nabu/app/reduce_dark_flat.py +6 -3
- nabu/app/rotate.py +4 -4
- nabu/app/stitching.py +16 -2
- nabu/app/tests/test_reduce_dark_flat.py +18 -2
- nabu/app/validator.py +4 -4
- nabu/cuda/convolution.py +8 -376
- nabu/cuda/fft.py +4 -0
- nabu/cuda/kernel.py +4 -4
- nabu/cuda/medfilt.py +5 -158
- nabu/cuda/padding.py +5 -71
- nabu/cuda/processing.py +23 -2
- nabu/cuda/src/ElementOp.cu +78 -0
- nabu/cuda/src/backproj.cu +28 -2
- nabu/cuda/src/fourier_wavelets.cu +2 -2
- nabu/cuda/src/normalization.cu +23 -0
- nabu/cuda/src/padding.cu +2 -2
- nabu/cuda/src/transpose.cu +16 -0
- nabu/cuda/utils.py +39 -0
- nabu/estimation/alignment.py +10 -1
- nabu/estimation/cor.py +808 -38
- nabu/estimation/cor_sino.py +7 -9
- nabu/estimation/tests/test_cor.py +85 -3
- nabu/io/reader.py +26 -18
- nabu/io/tests/test_cast_volume.py +3 -3
- nabu/io/tests/test_detector_distortion.py +3 -3
- nabu/io/tiffwriter_zmm.py +2 -2
- nabu/io/utils.py +14 -4
- nabu/io/writer.py +5 -3
- nabu/misc/fftshift.py +6 -0
- nabu/misc/histogram.py +5 -285
- nabu/misc/histogram_cuda.py +8 -104
- nabu/misc/kernel_base.py +3 -121
- nabu/misc/padding_base.py +5 -69
- nabu/misc/processing_base.py +3 -107
- nabu/misc/rotation.py +5 -62
- nabu/misc/rotation_cuda.py +5 -65
- nabu/misc/transpose.py +6 -0
- nabu/misc/unsharp.py +3 -78
- nabu/misc/unsharp_cuda.py +5 -52
- nabu/misc/unsharp_opencl.py +8 -85
- nabu/opencl/fft.py +6 -0
- nabu/opencl/kernel.py +21 -6
- nabu/opencl/padding.py +5 -72
- nabu/opencl/processing.py +27 -5
- nabu/opencl/src/backproj.cl +3 -3
- nabu/opencl/src/fftshift.cl +65 -12
- nabu/opencl/src/padding.cl +2 -2
- nabu/opencl/src/roll.cl +96 -0
- nabu/opencl/src/transpose.cl +16 -0
- nabu/pipeline/config_validators.py +63 -3
- nabu/pipeline/dataset_validator.py +2 -2
- nabu/pipeline/estimators.py +193 -35
- nabu/pipeline/fullfield/chunked.py +34 -17
- nabu/pipeline/fullfield/chunked_cuda.py +7 -5
- nabu/pipeline/fullfield/computations.py +48 -13
- nabu/pipeline/fullfield/nabu_config.py +13 -13
- nabu/pipeline/fullfield/processconfig.py +10 -5
- nabu/pipeline/fullfield/reconstruction.py +1 -2
- nabu/pipeline/helical/fbp.py +5 -0
- nabu/pipeline/helical/filtering.py +12 -9
- nabu/pipeline/helical/gridded_accumulator.py +179 -33
- nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
- nabu/pipeline/helical/helical_reconstruction.py +56 -18
- nabu/pipeline/helical/span_strategy.py +1 -1
- nabu/pipeline/helical/tests/test_accumulator.py +4 -0
- nabu/pipeline/params.py +23 -2
- nabu/pipeline/processconfig.py +3 -8
- nabu/pipeline/tests/test_chunk_reader.py +78 -0
- nabu/pipeline/tests/test_estimators.py +120 -2
- nabu/pipeline/utils.py +25 -0
- nabu/pipeline/writer.py +2 -0
- nabu/preproc/ccd_cuda.py +9 -7
- nabu/preproc/ctf.py +21 -26
- nabu/preproc/ctf_cuda.py +25 -25
- nabu/preproc/double_flatfield.py +14 -2
- nabu/preproc/double_flatfield_cuda.py +7 -11
- nabu/preproc/flatfield_cuda.py +23 -27
- nabu/preproc/phase.py +19 -24
- nabu/preproc/phase_cuda.py +21 -21
- nabu/preproc/shift_cuda.py +58 -28
- nabu/preproc/tests/test_ctf.py +5 -5
- nabu/preproc/tests/test_double_flatfield.py +2 -2
- nabu/preproc/tests/test_vshift.py +13 -2
- nabu/processing/__init__.py +0 -0
- nabu/processing/convolution_cuda.py +375 -0
- nabu/processing/fft_base.py +163 -0
- nabu/processing/fft_cuda.py +256 -0
- nabu/processing/fft_opencl.py +54 -0
- nabu/processing/fftshift.py +134 -0
- nabu/processing/histogram.py +286 -0
- nabu/processing/histogram_cuda.py +103 -0
- nabu/processing/kernel_base.py +126 -0
- nabu/processing/medfilt_cuda.py +159 -0
- nabu/processing/muladd.py +29 -0
- nabu/processing/muladd_cuda.py +68 -0
- nabu/processing/padding_base.py +71 -0
- nabu/processing/padding_cuda.py +75 -0
- nabu/processing/padding_opencl.py +77 -0
- nabu/processing/processing_base.py +123 -0
- nabu/processing/roll_opencl.py +64 -0
- nabu/processing/rotation.py +63 -0
- nabu/processing/rotation_cuda.py +66 -0
- nabu/processing/tests/__init__.py +0 -0
- nabu/processing/tests/test_fft.py +268 -0
- nabu/processing/tests/test_fftshift.py +71 -0
- nabu/{misc → processing}/tests/test_histogram.py +2 -4
- nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
- nabu/processing/tests/test_muladd.py +54 -0
- nabu/{cuda → processing}/tests/test_padding.py +119 -75
- nabu/processing/tests/test_roll.py +63 -0
- nabu/{misc → processing}/tests/test_rotation.py +3 -2
- nabu/processing/tests/test_transpose.py +72 -0
- nabu/{misc → processing}/tests/test_unsharp.py +41 -8
- nabu/processing/transpose.py +126 -0
- nabu/processing/unsharp.py +79 -0
- nabu/processing/unsharp_cuda.py +53 -0
- nabu/processing/unsharp_opencl.py +75 -0
- nabu/reconstruction/fbp.py +34 -10
- nabu/reconstruction/fbp_base.py +35 -16
- nabu/reconstruction/fbp_opencl.py +7 -12
- nabu/reconstruction/filtering.py +2 -2
- nabu/reconstruction/filtering_cuda.py +13 -14
- nabu/reconstruction/filtering_opencl.py +3 -4
- nabu/reconstruction/projection.py +2 -0
- nabu/reconstruction/rings.py +158 -1
- nabu/reconstruction/rings_cuda.py +218 -58
- nabu/reconstruction/sinogram_cuda.py +16 -12
- nabu/reconstruction/tests/test_deringer.py +116 -14
- nabu/reconstruction/tests/test_fbp.py +22 -31
- nabu/reconstruction/tests/test_filtering.py +11 -2
- nabu/resources/dataset_analyzer.py +89 -26
- nabu/resources/nxflatfield.py +2 -2
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +9 -2
- nabu/stitching/alignment.py +184 -0
- nabu/stitching/config.py +241 -39
- nabu/stitching/definitions.py +6 -0
- nabu/stitching/frame_composition.py +4 -2
- nabu/stitching/overlap.py +99 -3
- nabu/stitching/sample_normalization.py +60 -0
- nabu/stitching/slurm_utils.py +10 -10
- nabu/stitching/tests/test_alignment.py +99 -0
- nabu/stitching/tests/test_config.py +16 -1
- nabu/stitching/tests/test_overlap.py +68 -2
- nabu/stitching/tests/test_sample_normalization.py +49 -0
- nabu/stitching/tests/test_slurm_utils.py +5 -5
- nabu/stitching/tests/test_utils.py +3 -33
- nabu/stitching/tests/test_z_stitching.py +391 -22
- nabu/stitching/utils.py +144 -202
- nabu/stitching/z_stitching.py +309 -126
- nabu/testutils.py +18 -0
- nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
- nabu/utils.py +32 -6
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
- nabu-2024.1.0rc3.dist-info/RECORD +296 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
- nabu/conftest.py +0 -14
- nabu/opencl/fftshift.py +0 -92
- nabu/opencl/tests/test_fftshift.py +0 -55
- nabu/opencl/tests/test_padding.py +0 -84
- nabu-2023.2.1.dist-info/RECORD +0 -252
- /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/stitching/utils.py
CHANGED
@@ -1,18 +1,24 @@
|
|
1
1
|
from distutils.version import StrictVersion
|
2
|
-
|
3
2
|
from typing import Optional, Union
|
4
3
|
import logging
|
4
|
+
import functools
|
5
5
|
import numpy
|
6
|
+
from scipy.ndimage import affine_transform
|
6
7
|
from tomoscan.scanbase import TomoScanBase
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from nabu.estimation.alignment import AlignmentBase
|
10
|
-
from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer
|
8
|
+
from tomoscan.volumebase import VolumeBase
|
9
|
+
from nxtomo.utils.transformation import build_matrix, UDDetTransformation
|
11
10
|
from silx.utils.enum import Enum as _Enum
|
12
|
-
from scipy.ndimage import shift as scipy_shift
|
13
11
|
from scipy.fft import rfftn as local_fftn
|
14
12
|
from scipy.fft import irfftn as local_ifftn
|
15
|
-
from
|
13
|
+
from silx.utils.enum import Enum as _Enum
|
14
|
+
from nxtomo.utils.transformation import build_matrix, UDDetTransformation
|
15
|
+
from tomoscan.scanbase import TomoScanBase
|
16
|
+
from .overlap import OverlapStitchingStrategy, ZStichOverlapKernel
|
17
|
+
from .alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
|
18
|
+
from ..misc import fourier_filters
|
19
|
+
from ..estimation.alignment import AlignmentBase
|
20
|
+
from ..resources.dataset_analyzer import HDF5DatasetAnalyzer
|
21
|
+
from ..resources.nxflatfield import update_dataset_info_flats_darks
|
16
22
|
|
17
23
|
try:
|
18
24
|
import itk
|
@@ -40,7 +46,6 @@ class ShiftAlgorithm(_Enum):
|
|
40
46
|
|
41
47
|
NABU_FFT = "nabu-fft"
|
42
48
|
SKIMAGE = "skimage"
|
43
|
-
SHIFT_GRID = "shift-grid"
|
44
49
|
ITK_IMG_REG_V4 = "itk-img-reg-v4"
|
45
50
|
NONE = "None"
|
46
51
|
|
@@ -94,7 +99,6 @@ def find_frame_relative_shifts(
|
|
94
99
|
):
|
95
100
|
from nabu.stitching.config import (
|
96
101
|
KEY_WINDOW_SIZE,
|
97
|
-
KEY_SCORE_METHOD,
|
98
102
|
KEY_LOW_PASS_FILTER,
|
99
103
|
KEY_HIGH_PASS_FILTER,
|
100
104
|
) # avoid cyclic import
|
@@ -144,52 +148,49 @@ def find_frame_relative_shifts(
|
|
144
148
|
initial_shifts = numpy.array(estimated_shifts).copy()
|
145
149
|
extra_shifts = numpy.array([0.0, 0.0])
|
146
150
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
(y_cross_correlation_function, x_cross_correlation_function),
|
151
|
-
(y_shifts_params, x_shifts_params),
|
152
|
-
):
|
153
|
-
if method is ShiftAlgorithm.NABU_FFT:
|
154
|
-
extra_shifts[axis] = find_shift_correlate(img1=overlap_upper_frame, img2=overlap_lower_frame)[axis]
|
155
|
-
elif method is ShiftAlgorithm.SKIMAGE:
|
156
|
-
if not __has_sk_phase_correlation__:
|
157
|
-
raise ValueError("scikit-image not installed. Cannot do phase correlation from it")
|
158
|
-
else:
|
159
|
-
found_shift, _, _ = phase_cross_correlation(
|
160
|
-
reference_image=overlap_upper_frame, moving_image=overlap_lower_frame, space="real"
|
161
|
-
)
|
162
|
-
extra_shifts[axis] = found_shift[axis]
|
163
|
-
elif method is ShiftAlgorithm.NONE: # None as a string in case some uers give this value
|
164
|
-
# in the case we don't want to apply algorithm keep the initial 'guessed' shifts
|
165
|
-
continue
|
166
|
-
elif method is ShiftAlgorithm.SHIFT_GRID:
|
167
|
-
if axis == 0:
|
168
|
-
window_size = (int(y_shifts_params.get(KEY_WINDOW_SIZE, 200)), 0)
|
169
|
-
elif axis == 1:
|
170
|
-
window_size = (0, int(x_shifts_params.get(KEY_WINDOW_SIZE, 200)))
|
171
|
-
score_method = params.get(KEY_SCORE_METHOD, ScoreMethod.STD)
|
172
|
-
extra_shifts[axis] = -shift_grid_search(
|
173
|
-
img_1=overlap_upper_frame,
|
174
|
-
img_2=overlap_lower_frame,
|
175
|
-
window_sizes=window_size,
|
176
|
-
step_size=1,
|
177
|
-
axis=(axis,),
|
178
|
-
score_method=score_method,
|
179
|
-
)[axis]
|
180
|
-
elif method is ShiftAlgorithm.ITK_IMG_REG_V4:
|
181
|
-
extra_shifts[axis] = find_shift_with_itk(img1=overlap_upper_frame, img2=overlap_lower_frame)[axis]
|
151
|
+
def skimage_proxy(img1, img2):
|
152
|
+
if not __has_sk_phase_correlation__:
|
153
|
+
raise ValueError("scikit-image not installed. Cannot do phase correlation from it")
|
182
154
|
else:
|
183
|
-
|
184
|
-
|
155
|
+
found_shift, _, _ = phase_cross_correlation(reference_image=img1, moving_image=img2, space="real")
|
156
|
+
return -found_shift
|
157
|
+
|
158
|
+
shift_methods = {
|
159
|
+
ShiftAlgorithm.NABU_FFT: functools.partial(
|
160
|
+
find_shift_correlate, img1=overlap_upper_frame, img2=overlap_lower_frame
|
161
|
+
),
|
162
|
+
ShiftAlgorithm.SKIMAGE: functools.partial(skimage_proxy, img1=overlap_upper_frame, img2=overlap_lower_frame),
|
163
|
+
ShiftAlgorithm.ITK_IMG_REG_V4: functools.partial(
|
164
|
+
find_shift_with_itk, img1=overlap_upper_frame, img2=overlap_lower_frame
|
165
|
+
),
|
166
|
+
ShiftAlgorithm.NONE: functools.partial(lambda: (0.0, 0.0)),
|
167
|
+
}
|
168
|
+
|
169
|
+
res_algo = {}
|
170
|
+
for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
|
171
|
+
if shift_alg not in shift_methods:
|
172
|
+
raise ValueError(f"requested image alignment function not handled ({shift_alg})")
|
173
|
+
try:
|
174
|
+
res_algo[shift_alg] = shift_methods[shift_alg]()
|
175
|
+
except Exception as e:
|
176
|
+
_logger.error(f"Failed to find shift from {shift_alg.value}. Error is {e}")
|
177
|
+
res_algo[shift_alg] = (0, 0)
|
178
|
+
|
179
|
+
extra_shifts = (
|
180
|
+
res_algo[y_cross_correlation_function][0],
|
181
|
+
res_algo[x_cross_correlation_function][1],
|
182
|
+
)
|
185
183
|
|
184
|
+
final_rel_shifts = numpy.array(extra_shifts) + initial_shifts
|
186
185
|
return tuple([int(shift) for shift in final_rel_shifts])
|
187
186
|
|
188
187
|
|
189
188
|
def find_volumes_relative_shifts(
|
190
|
-
upper_volume:
|
191
|
-
lower_volume:
|
189
|
+
upper_volume: VolumeBase,
|
190
|
+
lower_volume: VolumeBase,
|
192
191
|
estimated_shifts,
|
192
|
+
dim_axis_1: int,
|
193
|
+
dtype,
|
193
194
|
flip_ud_upper_frame: bool = False,
|
194
195
|
flip_ud_lower_frame: bool = False,
|
195
196
|
slice_for_shift: Union[int, str] = "middle",
|
@@ -197,15 +198,59 @@ def find_volumes_relative_shifts(
|
|
197
198
|
y_cross_correlation_function=None,
|
198
199
|
x_shifts_params: Optional[dict] = None,
|
199
200
|
y_shifts_params: Optional[dict] = None,
|
201
|
+
alignment_axis_2="center",
|
202
|
+
alignment_axis_1="center",
|
200
203
|
):
|
204
|
+
"""
|
205
|
+
|
206
|
+
:param int dim_axis_1: axis 1 dimension (to handle axis 1 alignment)
|
207
|
+
"""
|
201
208
|
if y_shifts_params is None:
|
202
209
|
y_shifts_params = {}
|
203
210
|
|
204
211
|
if x_shifts_params is None:
|
205
212
|
x_shifts_params = {}
|
206
213
|
|
207
|
-
|
208
|
-
|
214
|
+
alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
|
215
|
+
alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
|
216
|
+
assert dim_axis_1 > 0, "dim_axis_1 <= 0"
|
217
|
+
|
218
|
+
if isinstance(slice_for_shift, str):
|
219
|
+
if slice_for_shift == "first":
|
220
|
+
slice_for_shift = 0
|
221
|
+
elif slice_for_shift == "last":
|
222
|
+
slice_for_shift = dim_axis_1
|
223
|
+
elif slice_for_shift == "middle":
|
224
|
+
slice_for_shift = dim_axis_1 // 2
|
225
|
+
else:
|
226
|
+
raise ValueError("invalid slice provided to search shift", slice_for_shift)
|
227
|
+
|
228
|
+
def get_slice_along_axis_1(volume: VolumeBase, index: int):
|
229
|
+
assert isinstance(index, int), f"index should be an int, {type(index)} provided"
|
230
|
+
volume_shape = volume.get_volume_shape()
|
231
|
+
if alignment_axis_1 is AlignmentAxis1.BACK:
|
232
|
+
front_empty_width = dim_axis_1 - volume_shape[1]
|
233
|
+
if index < front_empty_width:
|
234
|
+
return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
|
235
|
+
else:
|
236
|
+
return volume.get_slice(index=index - front_empty_width, axis=1)
|
237
|
+
elif alignment_axis_1 is AlignmentAxis1.FRONT:
|
238
|
+
if index >= volume_shape[1]:
|
239
|
+
return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
|
240
|
+
else:
|
241
|
+
return volume.get_slice(index=index, axis=1)
|
242
|
+
elif alignment_axis_1 is AlignmentAxis1.CENTER:
|
243
|
+
front_empty_width = (dim_axis_1 - volume_shape[1]) // 2
|
244
|
+
back_empty_width = dim_axis_1 - front_empty_width
|
245
|
+
if index < front_empty_width or index > back_empty_width:
|
246
|
+
return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
|
247
|
+
else:
|
248
|
+
return volume.get_slice(index=index - front_empty_width, axis=1)
|
249
|
+
else:
|
250
|
+
raise TypeError(f"unmanaged alignment mode {alignment_axis_1.value}")
|
251
|
+
|
252
|
+
upper_frame = get_slice_along_axis_1(upper_volume, index=slice_for_shift)
|
253
|
+
lower_frame = get_slice_along_axis_1(lower_volume, index=slice_for_shift)
|
209
254
|
if flip_ud_upper_frame:
|
210
255
|
upper_frame = numpy.flipud(upper_frame.copy())
|
211
256
|
if flip_ud_lower_frame:
|
@@ -214,14 +259,35 @@ def find_volumes_relative_shifts(
|
|
214
259
|
from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
|
215
260
|
|
216
261
|
w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
|
217
|
-
start_overlap = max(estimated_shifts[0] - w_window_size // 2, 0)
|
218
|
-
end_overlap = min(estimated_shifts[0] + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))
|
262
|
+
start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
|
263
|
+
end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))
|
219
264
|
|
220
265
|
if start_overlap == 0:
|
221
266
|
overlap_upper_frame = upper_frame[-end_overlap:]
|
222
267
|
else:
|
223
268
|
overlap_upper_frame = upper_frame[-end_overlap:-start_overlap]
|
224
269
|
overlap_lower_frame = lower_frame[start_overlap:end_overlap]
|
270
|
+
|
271
|
+
# align if necessary
|
272
|
+
if overlap_upper_frame.shape[1] != overlap_lower_frame.shape[1]:
|
273
|
+
overlap_frame_width = min(overlap_upper_frame.shape[1], overlap_lower_frame.shape[1])
|
274
|
+
if alignment_axis_2 is AlignmentAxis2.CENTER:
|
275
|
+
upper_frame_left_pos = overlap_upper_frame.shape[1] // 2 - overlap_frame_width // 2
|
276
|
+
upper_frame_right_pos = upper_frame_left_pos + overlap_frame_width
|
277
|
+
overlap_upper_frame = overlap_upper_frame[:, upper_frame_left_pos:upper_frame_right_pos]
|
278
|
+
|
279
|
+
lower_frame_left_pos = overlap_lower_frame.shape[1] // 2 - overlap_frame_width // 2
|
280
|
+
lower_frame_right_pos = lower_frame_left_pos + overlap_frame_width
|
281
|
+
overlap_lower_frame = overlap_lower_frame[:, lower_frame_left_pos:lower_frame_right_pos]
|
282
|
+
elif alignment_axis_2 is AlignmentAxis2.LEFT:
|
283
|
+
overlap_upper_frame = overlap_upper_frame[:, :overlap_frame_width]
|
284
|
+
overlap_lower_frame = overlap_lower_frame[:, :overlap_frame_width]
|
285
|
+
elif alignment_axis_2 is AlignmentAxis2.RIGTH:
|
286
|
+
overlap_upper_frame = overlap_upper_frame[:, -overlap_frame_width:]
|
287
|
+
overlap_lower_frame = overlap_lower_frame[:, -overlap_frame_width:]
|
288
|
+
else:
|
289
|
+
raise ValueError(f"Alignement {alignment_axis_2.value} is not handled")
|
290
|
+
|
225
291
|
if not overlap_upper_frame.shape == overlap_lower_frame.shape:
|
226
292
|
raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")
|
227
293
|
|
@@ -293,7 +359,6 @@ def find_projections_relative_shifts(
|
|
293
359
|
):
|
294
360
|
cor_options = x_shifts_params.copy()
|
295
361
|
cor_options.pop("img_reg_method", None)
|
296
|
-
cor_options.pop("score_method", None)
|
297
362
|
# remove all none numeric options because estimate_cor will call 'literal_eval' on them
|
298
363
|
|
299
364
|
upper_scan_dataset_info = HDF5DatasetAnalyzer(
|
@@ -326,16 +391,24 @@ def find_projections_relative_shifts(
|
|
326
391
|
|
327
392
|
# } else we will compute shift from the flat projections
|
328
393
|
|
329
|
-
def get_flat_fielded_proj(
|
394
|
+
def get_flat_fielded_proj(
|
395
|
+
scan: TomoScanBase, proj_index: int, reverse: bool, transformation_matrix: Optional[numpy.ndarray]
|
396
|
+
):
|
330
397
|
first_proj_idx = sorted(lower_scan.projections.keys(), reverse=reverse)[proj_index]
|
331
398
|
ff = scan.flat_field_correction(
|
332
399
|
(scan.projections[first_proj_idx],),
|
333
400
|
(first_proj_idx,),
|
334
401
|
)[0]
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
402
|
+
assert ff.ndim == 2, f"expects a single 2D frame. Get something with {ff.ndim} dimensions"
|
403
|
+
if transformation_matrix is not None:
|
404
|
+
assert (
|
405
|
+
transformation_matrix.ndim == 2
|
406
|
+
), f"expects a 2D transformation matrix. Get a {transformation_matrix.ndim} D"
|
407
|
+
if numpy.isclose(transformation_matrix[2, 2], -1):
|
408
|
+
transformation_matrix[2, :] = 0
|
409
|
+
transformation_matrix[0, 2] = 0
|
410
|
+
transformation_matrix[2, 2] = 1
|
411
|
+
ff = numpy.flipud(ff)
|
339
412
|
return ff
|
340
413
|
|
341
414
|
if isinstance(projection_for_shift, str):
|
@@ -357,26 +430,32 @@ def find_projections_relative_shifts(
|
|
357
430
|
f"projection_for_shift is expected to be an int. Not {type(projection_for_shift)} - {projection_for_shift}"
|
358
431
|
)
|
359
432
|
|
433
|
+
upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple()))
|
434
|
+
if flip_ud_upper_frame:
|
435
|
+
upper_scan_transformations.append(UDDetTransformation())
|
436
|
+
upper_scan_trans_matrix = build_matrix(upper_scan_transformations)
|
437
|
+
lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple()))
|
438
|
+
if flip_ud_lower_frame:
|
439
|
+
lower_scan_transformations.append(UDDetTransformation())
|
440
|
+
lower_scan_trans_matrix = build_matrix(lower_scan_transformations)
|
360
441
|
upper_proj = get_flat_fielded_proj(
|
361
442
|
upper_scan,
|
362
443
|
projection_for_shift,
|
363
444
|
reverse=False,
|
364
|
-
|
365
|
-
revert_y=upper_scan.get_y_flipped(default=False) ^ flip_ud_upper_frame,
|
445
|
+
transformation_matrix=upper_scan_trans_matrix,
|
366
446
|
)
|
367
447
|
lower_proj = get_flat_fielded_proj(
|
368
448
|
lower_scan,
|
369
449
|
projection_for_shift,
|
370
450
|
reverse=invert_order,
|
371
|
-
|
372
|
-
revert_y=lower_scan.get_y_flipped(default=False) ^ flip_ud_lower_frame,
|
451
|
+
transformation_matrix=lower_scan_trans_matrix,
|
373
452
|
)
|
374
453
|
|
375
454
|
from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
|
376
455
|
|
377
456
|
w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
|
378
|
-
start_overlap = max(estimated_shifts[0] - w_window_size // 2, 0)
|
379
|
-
end_overlap = min(estimated_shifts[0] + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
|
457
|
+
start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
|
458
|
+
end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
|
380
459
|
if start_overlap == 0:
|
381
460
|
overlap_upper_frame = upper_proj[-end_overlap:]
|
382
461
|
else:
|
@@ -410,144 +489,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"):
|
|
410
489
|
|
411
490
|
(f_vals, fv, fh) = alignment.extract_peak_region_2d(cc, cc_vs=cc_vs, cc_hs=cc_hs)
|
412
491
|
shifts_vh = alignment.refine_max_position_2d(f_vals, fv, fh)
|
413
|
-
return shifts_vh
|
414
|
-
|
415
|
-
|
416
|
-
class ScoreMethod(_Enum):
|
417
|
-
STD = "standard deviation"
|
418
|
-
TV = "total variation"
|
419
|
-
TV_INVERSE = "1 / (total variation)"
|
420
|
-
STD_INVERSE = "1 / std"
|
421
|
-
|
422
|
-
@classmethod
|
423
|
-
def from_value(cls, value):
|
424
|
-
if isinstance(value, str):
|
425
|
-
# for string handle the case where value as been provided as 'value'. As there is spaces this can happen
|
426
|
-
value = value.lstrip("'").rstrip("'")
|
427
|
-
if value in ("tv", "TV"):
|
428
|
-
return ScoreMethod.TV
|
429
|
-
elif value in ("std", "STD"):
|
430
|
-
return ScoreMethod.STD
|
431
|
-
else:
|
432
|
-
return super().from_value(value=value)
|
433
|
-
|
434
|
-
|
435
|
-
def compute_score_contrast_std(data: numpy.ndarray):
|
436
|
-
"""
|
437
|
-
Compute a contrast score by simply computing the standard deviation of
|
438
|
-
the frame
|
439
|
-
:param numpy.ndarray data: frame for which we should compute the score
|
440
|
-
:return: score of the frame
|
441
|
-
:rtype: float
|
442
|
-
"""
|
443
|
-
if data is None:
|
444
|
-
return None
|
445
|
-
else:
|
446
|
-
return data.std()
|
447
|
-
|
448
|
-
|
449
|
-
def compute_tv_score(data: numpy.ndarray):
|
450
|
-
"""
|
451
|
-
Compute the data score as image total variation
|
452
|
-
|
453
|
-
:param numpy.ndarray data: frame for which we should compute the score
|
454
|
-
:return: score of the frame
|
455
|
-
:rtype: float
|
456
|
-
"""
|
457
|
-
tv = numpy.sum(numpy.sqrt(numpy.gradient(data, axis=0) ** 2 + numpy.gradient(data, axis=1) ** 2))
|
458
|
-
return tv
|
459
|
-
|
460
|
-
|
461
|
-
def compute_score(img_1, img_2, shift, score_method, score_region, return_img=False):
|
462
|
-
score_method = ScoreMethod.from_value(score_method)
|
463
|
-
|
464
|
-
img_2 = scipy_shift(img_2, shift=shift)
|
465
|
-
|
466
|
-
img_2_reduce = img_2[
|
467
|
-
score_region[0].start : score_region[0].stop,
|
468
|
-
score_region[1].start : score_region[1].stop,
|
469
|
-
]
|
470
|
-
img_1_reduce = img_1[
|
471
|
-
score_region[0].start : score_region[0].stop,
|
472
|
-
score_region[1].start : score_region[1].stop,
|
473
|
-
]
|
474
|
-
|
475
|
-
img_sum = img_1_reduce * 0.5 + img_2_reduce * 0.5
|
476
|
-
|
477
|
-
if score_method is ScoreMethod.TV:
|
478
|
-
result = compute_tv_score(img_sum)
|
479
|
-
elif score_method is ScoreMethod.STD:
|
480
|
-
result = compute_score_contrast_std(img_sum)
|
481
|
-
elif score_method is ScoreMethod.TV_INVERSE:
|
482
|
-
result = 1 / compute_tv_score(img_sum)
|
483
|
-
elif score_method is ScoreMethod.STD_INVERSE:
|
484
|
-
result = 1 / compute_score_contrast_std(img_sum)
|
485
|
-
else:
|
486
|
-
raise ValueError(f"{score_method} is not handled")
|
487
|
-
if return_img:
|
488
|
-
return result, img_sum
|
489
|
-
else:
|
490
|
-
return result
|
491
|
-
|
492
|
-
|
493
|
-
def shift_grid_search(img_1, img_2, window_sizes: tuple, axis, step_size, score_method=ScoreMethod.STD):
|
494
|
-
"""
|
495
|
-
we could consider adding weights to do the exact same operation that will be done for the stitching... ? overkilled ?
|
496
|
-
:param tuple window_sizes: as y_size, x_size
|
497
|
-
"""
|
498
|
-
if not isinstance(window_sizes, tuple) and len(window_sizes) != 2:
|
499
|
-
raise TypeError(f"window_sizes is expected to be a tuple of two ints. {window_sizes} provided")
|
500
|
-
if not img_1.ndim == img_2.ndim == 2:
|
501
|
-
raise ValueError("image dimension should be 2D")
|
502
|
-
|
503
|
-
for value in axis:
|
504
|
-
if not value in (0, 1):
|
505
|
-
raise ValueError(f"axis {value} is not handled")
|
506
|
-
half_window_x = min((img_1.shape[1], img_2.shape[1], abs(window_sizes[1]))) // 2
|
507
|
-
half_window_y = min((img_1.shape[0], img_2.shape[0], abs(window_sizes[0]))) // 2
|
508
|
-
if 1 in axis:
|
509
|
-
x_research = numpy.arange(-half_window_x, half_window_x, step_size)
|
510
|
-
x_score_region = slice(
|
511
|
-
int(half_window_x * 2),
|
512
|
-
int(min(img_1.shape[1], img_2.shape[1]) - (half_window_x * 2)),
|
513
|
-
)
|
514
|
-
else:
|
515
|
-
x_research = tuple((0,))
|
516
|
-
x_score_region = slice(0, int(min(img_1.shape[1], img_2.shape[1])))
|
517
|
-
|
518
|
-
if 0 in axis:
|
519
|
-
y_research = numpy.arange(-half_window_y, half_window_y, step_size)
|
520
|
-
y_score_region = slice(
|
521
|
-
int((half_window_y * 2)),
|
522
|
-
int(min(img_1.shape[0], img_2.shape[0]) - (half_window_y * 2)),
|
523
|
-
)
|
524
|
-
else:
|
525
|
-
y_research = tuple((0,))
|
526
|
-
y_score_region = slice(0, int(min(img_1.shape[0], img_2.shape[0])))
|
527
|
-
score_region = (
|
528
|
-
y_score_region,
|
529
|
-
x_score_region,
|
530
|
-
)
|
531
|
-
|
532
|
-
score_width = score_region[1].stop - score_region[1].start
|
533
|
-
if score_width < 10:
|
534
|
-
_logger.warning("score_width seems very low. Try reducing window_sizes")
|
535
|
-
|
536
|
-
best_score, best_shift = None, (0, 0)
|
537
|
-
for x_shift in x_research:
|
538
|
-
for y_shift in y_research:
|
539
|
-
res = compute_score(
|
540
|
-
img_1,
|
541
|
-
img_2.copy(),
|
542
|
-
shift=(y_shift, x_shift),
|
543
|
-
score_method=score_method,
|
544
|
-
score_region=score_region,
|
545
|
-
)
|
546
|
-
local_score = res
|
547
|
-
if best_score is None or (local_score is not None and local_score > best_score):
|
548
|
-
best_score = local_score
|
549
|
-
best_shift = (y_shift, x_shift)
|
550
|
-
return numpy.array(best_shift)
|
492
|
+
return -shifts_vh
|
551
493
|
|
552
494
|
|
553
495
|
def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple:
|