nabu 2025.1.0.dev5__py3-none-any.whl → 2025.1.0.dev13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/double_flatfield.py +18 -5
- nabu/app/multicor.py +25 -10
- nabu/app/reconstruct_helical.py +4 -4
- nabu/app/stitching.py +7 -2
- nabu/cuda/src/backproj.cu +10 -10
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/utils.py +1 -1
- nabu/estimation/cor.py +3 -3
- nabu/io/cast_volume.py +16 -0
- nabu/io/reader.py +3 -2
- nabu/opencl/src/backproj.cl +10 -10
- nabu/pipeline/estimators.py +6 -6
- nabu/pipeline/fullfield/chunked.py +13 -13
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +16 -4
- nabu/pipeline/fullfield/processconfig.py +25 -4
- nabu/pipeline/fullfield/reconstruction.py +9 -4
- nabu/pipeline/helical/gridded_accumulator.py +1 -1
- nabu/pipeline/helical/helical_reconstruction.py +2 -2
- nabu/pipeline/helical/nabu_config.py +1 -1
- nabu/pipeline/helical/weight_balancer.py +1 -1
- nabu/pipeline/params.py +8 -3
- nabu/preproc/shift.py +1 -1
- nabu/preproc/tests/test_ctf.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/processing/fft_base.py +6 -2
- nabu/processing/fft_cuda.py +17 -167
- nabu/processing/fft_opencl.py +19 -2
- nabu/processing/padding_cuda.py +0 -1
- nabu/processing/processing_base.py +11 -5
- nabu/processing/tests/test_fft.py +1 -63
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +8 -0
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +21 -20
- nabu/reconstruction/filtering_opencl.py +8 -14
- nabu/reconstruction/hbp.py +10 -10
- nabu/reconstruction/mlem.py +3 -0
- nabu/reconstruction/rings_cuda.py +41 -13
- nabu/reconstruction/tests/test_cone.py +35 -0
- nabu/reconstruction/tests/test_deringer.py +2 -2
- nabu/reconstruction/tests/test_fbp.py +35 -14
- nabu/reconstruction/tests/test_filtering.py +14 -5
- nabu/reconstruction/tests/test_halftomo.py +1 -1
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/resources/dataset_analyzer.py +34 -2
- nabu/resources/tests/test_extract.py +4 -2
- nabu/stitching/config.py +6 -1
- nabu/stitching/stitcher/dumper/__init__.py +1 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +105 -1
- nabu/stitching/stitcher/post_processing.py +14 -4
- nabu/stitching/stitcher/pre_processing.py +1 -1
- nabu/stitching/stitcher/single_axis.py +8 -7
- nabu/stitching/stitcher/z_stitcher.py +8 -4
- nabu/stitching/utils/utils.py +2 -2
- nabu/testutils.py +2 -2
- nabu/utils.py +9 -2
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/METADATA +9 -28
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/RECORD +66 -65
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/WHEEL +1 -1
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/entry_points.txt +0 -0
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info/licenses}/LICENSE +0 -0
- {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/top_level.txt +0 -0
nabu/reconstruction/tests/test_deringer.py
CHANGED
@@ -38,7 +38,7 @@ if __do_long_tests__:
             "sigma": [1.0, 2.0],
             "wname": ["db15", "haar", "rbio4.4"],
             "padding": [None, (100, 100), (50, 71)],
-            "fft_implem": ["
+            "fft_implem": ["vkfft"],
         }
     )

@@ -107,7 +107,7 @@ class TestDeringer:

     @pytest.mark.skipif(
         not (__has_cuda_deringer__) or munchetal_filter is None,
-        reason="Need pycuda, pycudwt and (
+        reason="Need pycuda, pycudwt and (cupy? or pyvkfft) for this test",
     )
     @pytest.mark.parametrize("config", fw_scenarios)
     def test_cuda_munch_deringer(self, config):
nabu/reconstruction/tests/test_fbp.py
CHANGED
@@ -7,10 +7,10 @@ from nabu.testutils import get_data, generate_tests_scenarios, __do_long_tests__
 from nabu.cuda.utils import get_cuda_context, __has_pycuda__
 from nabu.opencl.utils import get_opencl_context, __has_pyopencl__

-from nabu.processing.fft_cuda import
+from nabu.processing.fft_cuda import has_vkfft as has_vkfft_cu
 from nabu.processing.fft_opencl import has_vkfft as has_vkfft_cl

-__has_pycuda__ = __has_pycuda__ and
+__has_pycuda__ = __has_pycuda__ and has_vkfft_cu()
 __has_pyopencl__ = __has_pyopencl__ and has_vkfft_cl()

 if __has_pycuda__:
@@ -40,7 +40,7 @@ def bootstrap(request):
     # always use contiguous arrays
     cls.sino_511 = np.ascontiguousarray(cls.sino_512[:, :-1])
     # Could be set to 5.0e-2 when using textures. When not using textures, interpolation slightly differs
-    cls.tol = 5.1e-2
+    cls.tol = 2e-2  # 5.1e-2

     if __has_pycuda__:
         cls.cuda_ctx = get_cuda_context(cleanup_at_exit=False)
@@ -62,7 +62,7 @@ class TestFBP:
     def _get_backprojector(self, config, *bp_args, **bp_kwargs):
         if config["backend"] == "cuda":
             if not (__has_pycuda__):
-                pytest.skip("Need pycuda + (
+                pytest.skip("Need pycuda + (cupy? or pyvkfft)")
             Backprojector = CudaBackprojector
             ctx = self.cuda_ctx
         else:
@@ -98,10 +98,14 @@ class TestFBP:
         B = self._get_backprojector(config, (500, 512))
         res = self.apply_fbp(config, B, self.sino_512)

-
-
+        diff = res - self.ref_512
+        tol = self.tol
+        if not (B._use_textures):
+            diff = clip_to_inner_circle(diff)
+            tol = 5.1e-2
+        err_max = np.max(np.abs(diff))

-        assert err_max <
+        assert err_max < tol, "Something wrong with config=%s" % (str(config))

     @pytest.mark.parametrize("config", scenarios)
     def test_fbp_511(self, config):
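Note: the tests above compare reconstructions only inside the inscribed circle when textures are unavailable, since interpolation differs outside the field of view. `clip_to_inner_circle` itself is not shown in this diff; a minimal stand-in consistent with how the tests call it (`clip_to_inner_circle(diff)`, `radius_factor=1`) could look like this — the exact upstream semantics are an assumption:

    import numpy as np

    def clip_to_inner_circle(img, radius_factor=1.0):
        # Stand-in (not nabu's implementation): zero out everything outside
        # the circle inscribed in the image, keeping only the field of view.
        ny, nx = img.shape
        cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
        y, x = np.ogrid[:ny, :nx]
        radius = radius_factor * min(cy, cx)
        return img * ((y - cy) ** 2 + (x - cx) ** 2 <= radius**2)

    # Usage mirroring the test: compare only inside the inscribed circle
    diff = np.ones((512, 512), dtype="f")
    err_max = np.max(np.abs(clip_to_inner_circle(diff)))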
@@ -112,10 +116,29 @@ class TestFBP:
         res = self.apply_fbp(config, B, self.sino_511)
         ref = self.ref_512[:-1, :-1]

-
-        err_max = np.max(np.abs(
-
-
+        diff = clip_to_inner_circle(res - ref)
+        err_max = np.max(np.abs(diff))
+        tol = self.tol
+        if not (B._use_textures):
+            tol = 5.1e-2
+
+        assert err_max < tol, "Something wrong with config=%s" % (str(config))
+
+        # Cropping the sinogram to sino[:, :-1] gives a reconstruction
+        # that is not fully equivalent to rec512[:-1, :-1] in the upper half of the image, outside FoV.
+        # However, nabu Backprojector gives the same results as astra
+        # Probably we should check this instead:
+
+        # B = self._get_backprojector(config, (500, 511), rot_center=255.5, extra_options={"centered_axis": True})
+        # res = self.apply_fbp(config, B, self.sino_511)
+        # import astra
+        # proj_geom = astra.create_proj_geom('parallel', 1, 511, B.angles)
+        # proj_geom = astra.geom_postalignment(proj_geom, - 0.5)
+        # vol_geom = astra.create_vol_geom(511, 511)
+        # proj_id = astra.create_projector("cuda", proj_geom, vol_geom)
+        # ref = astra.create_reconstruction("FBP_CUDA", proj_id, self.sino_511, proj_id)[1]
+        # err_max = np.max(np.abs(res - ref))
+        # assert err_max < self.tol, "Something wrong with config=%s" % (str(config))

     @pytest.mark.parametrize("config", scenarios)
     def test_fbp_roi(self, config):
@@ -194,8 +217,7 @@ class TestFBP:
         )
         res_noclip = B0.fbp(sino)
         ref = clip_to_inner_circle(res_noclip, radius_factor=1)
-
-        err_max = np.max(abs_diff)
+        err_max = np.max(np.abs(res - ref))
         assert err_max < tol, "Max error is too high for rot_center=%s ; %s" % (str(rot_center), str(config))

         # Test with custom outer circle value
@@ -223,7 +245,6 @@ class TestFBP:
         ref = B0.fbp(self.sino_512)

         # Check that "centered_axis" worked
-
         B = self._get_backprojector(config, sino.shape, rot_center=rot_center, extra_options={"centered_axis": True})
         res = self.apply_fbp(config, B, sino)
         # The outside region (outer circle) is different as "res" is a wider slice
nabu/reconstruction/tests/test_filtering.py
CHANGED
@@ -14,11 +14,13 @@ if __has_pyopencl__:
     from nabu.opencl.processing import OpenCLProcessing
     from nabu.reconstruction.filtering_opencl import OpenCLSinoFilter, __has_vkfft__

-filters_to_test = ["ramlak", "shepp-logan"
+filters_to_test = ["ramlak", "shepp-logan"]
 padding_modes_to_test = ["constant", "edge"]
+crop_filtered_data = [True]
 if __do_long_tests__:
-    filters_to_test
+    filters_to_test.extend(["cosine", "hamming", "hann", "lanczos"])
     padding_modes_to_test = SinoFilter.available_padding_modes
+    crop_filtered_data = [True, False]

 tests_scenarios = generate_tests_scenarios(
     {
@@ -26,6 +28,7 @@ tests_scenarios = generate_tests_scenarios(
         "padding_mode": padding_modes_to_test,
         "output_provided": [True, False],
         "truncated_sino": [True, False],
+        "crop_filtered_data": crop_filtered_data,
     }
 )
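Note: `generate_tests_scenarios` (from `nabu.testutils`) turns a dict of parameter lists into one scenario per combination, which is why the new `crop_filtered_data` entry multiplies the number of test cases. A sketch of that behavior, under the assumption of plain cartesian-product semantics:

    from itertools import product

    def generate_scenarios(params):
        # One dict per element of the cartesian product of all value lists
        keys = list(params)
        return [dict(zip(keys, values)) for values in product(*params.values())]

    scenarios = generate_scenarios(
        {
            "filter_name": ["ramlak", "shepp-logan"],
            "padding_mode": ["constant", "edge"],
            "crop_filtered_data": [True, False],
        }
    )
    assert len(scenarios) == 2 * 2 * 2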
@@ -61,9 +64,10 @@ class TestSinoFilter:
             sino.shape,
             filter_name=config["filter_name"],
             padding_mode=config["padding_mode"],
+            crop_filtered_data=config["crop_filtered_data"],
         )
         if config["output_provided"]:
-            output = np.
+            output = np.zeros(sino_filter.output_shape, "f")
         else:
             output = None
         res = sino_filter.filter_sino(sino, output=output)
@@ -71,7 +75,11 @@ class TestSinoFilter:
         assert id(res) == id(output), "when providing output, return value must not change"

         ref = filter_sinogram(
-            sino,
+            sino,
+            sino_filter.dwidth_padded,
+            filter_name=config["filter_name"],
+            padding_mode=config["padding_mode"],
+            crop_filtered_data=config["crop_filtered_data"],
         )

         assert np.allclose(res, ref, atol=4e-6)
@@ -86,10 +94,11 @@ class TestSinoFilter:
             sino.shape,
             filter_name=config["filter_name"],
             padding_mode=config["padding_mode"],
+            crop_filtered_data=config["crop_filtered_data"],
             cuda_options={"ctx": self.ctx_cuda},
         )
         if config["output_provided"]:
-            output = garray.zeros(
+            output = garray.zeros(sino_filter.output_shape, "f")
         else:
             output = None
         res = sino_filter.filter_sino(sino, output=output)
nabu/reconstruction/tests/test_halftomo.py
CHANGED
@@ -42,7 +42,7 @@ class TestHalftomo:
     def _get_backprojector(self, config, *bp_args, **bp_kwargs):
         if config["backend"] == "cuda":
             if not (__has_pycuda__):
-                pytest.skip("Need pycuda +
+                pytest.skip("Need pycuda + cupy? or vkfft")
             Backprojector = CudaBackprojector
             ctx = self.cuda_ctx
         else:
nabu/reconstruction/tests/test_reconstructor.py
CHANGED
@@ -48,7 +48,7 @@ def bootstrap(request):
 )
 @pytest.mark.usefixtures("bootstrap")
 class TestReconstructor:
-    @pytest.mark.skipif(not (__has_cuda_fbp__), reason="need pycuda and (
+    @pytest.mark.skipif(not (__has_cuda_fbp__), reason="need pycuda and (cupy? or vkfft)")
     @pytest.mark.parametrize("config", scenarios)
     def test_cuda_reconstructor(self, config):
         data = self.projs
nabu/resources/dataset_analyzer.py
CHANGED
@@ -2,10 +2,12 @@ import os
 import numpy as np
 from silx.io.url import DataUrl
 from silx.io import get_data
+from tomoscan import __version__ as __tomoscan_version__
 from tomoscan.esrf.scan.edfscan import EDFTomoScan
 from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
+from packaging.version import parse as parse_version

-from ..utils import check_supported, indices_to_slices
+from ..utils import BaseClassError, check_supported, indices_to_slices
 from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader
 from ..io.utils import get_compacted_dataslices
 from .utils import get_values_from_file, is_hdf5_extension
@@ -143,7 +145,10 @@ class DatasetAnalyzer:
         Return the sample-detector distance in meters.
         """
         if self._distance is None:
-
+            if parse_version(__tomoscan_version__) < parse_version("2.2"):
+                self._distance = abs(self.dataset_scanner.distance)
+            else:
+                self._distance = abs(self.dataset_scanner.sample_detector_distance)
         return self._distance

     @distance.setter
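Note: tomoscan renamed the detector-distance attribute to `sample_detector_distance` (available from version 2.2 on, per this diff), hence the version gate. The pattern in isolation, with a dummy scanner standing in for tomoscan's:

    from packaging.version import parse as parse_version

    class _DummyScanner:
        # hypothetical object exposing both the old and the new attribute
        distance = -0.5
        sample_detector_distance = -0.5

    def get_distance(scanner, tomoscan_version):
        if parse_version(tomoscan_version) < parse_version("2.2"):
            return abs(scanner.distance)  # old attribute name
        return abs(scanner.sample_detector_distance)  # new attribute name

    assert get_distance(_DummyScanner(), "2.1.0") == 0.5
    assert get_distance(_DummyScanner(), "2.2.0") == 0.5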
@@ -272,6 +277,14 @@ class DatasetAnalyzer:
     def darks(self, val):
         self._reduced_darks = val

+    @property
+    def scan_basename(self):
+        raise BaseClassError
+
+    @property
+    def scan_dirname(self):
+        raise BaseClassError
+

 class EDFDatasetAnalyzer(DatasetAnalyzer):
     """
@@ -328,6 +341,15 @@ class EDFDatasetAnalyzer(DatasetAnalyzer):
     def get_reader(self, **kwargs):
         return EDFStackReader(self.files, **kwargs)

+    @property
+    def scan_basename(self):
+        # os.path.basename(self.dataset_scanner.path)
+        return self.dataset_scanner.get_dataset_basename()
+
+    @property
+    def scan_dirname(self):
+        return self.dataset_scanner.path
+

 class HDF5DatasetAnalyzer(DatasetAnalyzer):
     """
@@ -460,6 +482,16 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
     def get_reader(self, **kwargs):
         return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs)

+    @property
+    def scan_basename(self):
+        # os.path.splitext(os.path.basename(self.dataset_hdf5_url.file_path()))[0]
+        return self.dataset_scanner.get_dataset_basename()
+
+    @property
+    def scan_dirname(self):
+        # os.path.dirname(di.dataset_hdf5_url.file_path())
+        return self.dataset_scanner.path
+

 def analyze_dataset(dataset_path, extra_options=None, logger=None):
     if not (os.path.isdir(dataset_path)):
nabu/resources/tests/test_extract.py
CHANGED
@@ -5,5 +5,7 @@ def test_list_match_queries():

     # entry0000 .... entry0099
     avail = ["entry%04d" % i for i in range(100)]
-
-    list_match_queries()
+    assert list_match_queries(avail, "entry0000") == ["entry0000"]
+    assert list_match_queries(avail, ["entry0001"]) == ["entry0001"]
+    assert list_match_queries(avail, ["entry000?"]) == ["entry%04d" % i for i in range(10)]
+    assert list_match_queries(avail, ["entry*"]) == avail
nabu/stitching/config.py
CHANGED
@@ -128,6 +128,8 @@ SLURM_MODULES_TO_LOADS = "modules"

 SLURM_CLEAN_SCRIPTS = "clean_scripts"

+SLURM_JOB_NAME = "job_name"
+
 # normalization by sample

 NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample"
@@ -421,6 +423,7 @@ class SlurmConfig:
     clean_script: bool = ""
     n_tasks: int = 1
     n_cpu_per_task: int = 4
+    job_name: str = ""

     def __post_init__(self) -> None:
         # make sure either 'modules' or 'preprocessing_command' is provided
@@ -441,6 +444,7 @@ class SlurmConfig:
             SLURM_CLEAN_SCRIPTS: self.clean_script,
             SLURM_NUMBER_OF_TASKS: self.n_tasks,
             SLURM_COR_PER_TASKS: self.n_cpu_per_task,
+            SLURM_JOB_NAME: self.job_name,
         }

     @staticmethod
@@ -457,6 +461,7 @@ class SlurmConfig:
             preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""),
             modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")),
             clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0],
+            job_name=config.get(SLURM_JOB_NAME, ""),
         )

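Note: the three hunks above thread a new `job_name` field through the `SlurmConfig` dataclass, its dict serialization, and its deserialization. A reduced sketch of that round-trip (field names taken from the diff; the real class has more fields and validation):

    from dataclasses import dataclass

    SLURM_JOB_NAME = "job_name"

    @dataclass
    class SlurmConfigSketch:
        # reduced stand-in for nabu.stitching.config.SlurmConfig
        n_tasks: int = 1
        n_cpu_per_task: int = 4
        job_name: str = ""  # new in this version

        def to_dict(self) -> dict:
            return {
                "n_tasks": self.n_tasks,
                "n_cpu_per_task": self.n_cpu_per_task,
                SLURM_JOB_NAME: self.job_name,
            }

        @staticmethod
        def from_dict(config: dict) -> "SlurmConfigSketch":
            return SlurmConfigSketch(job_name=config.get(SLURM_JOB_NAME, ""))

    cfg = SlurmConfigSketch(job_name="nabu-stitching")
    assert SlurmConfigSketch.from_dict(cfg.to_dict()).job_name == "nabu-stitching"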
@@ -774,7 +779,7 @@ class SingleAxisConfigMetaClass(type):
     warning: this class is used by tomwer as well
     """

-    def __new__(mcls, name, bases, attrs, axis=None):
+    def __new__(mcls, name, bases, attrs, axis=None):
         # assert axis is not None
         mcls = super().__new__(mcls, name, bases, attrs)
         mcls._axis = axis
nabu/stitching/stitcher/dumper/postprocessing.py
CHANGED
@@ -161,7 +161,7 @@ class OutputVolumeNoDDContext(OutputVolumeContext):

 class PostProcessingStitchingDumper(DumperBase):
     """
-    dumper to be used when save data
+    dumper to be used when save data during post-processing stitching (on reconstructed volume). Output is expected to be an NXtomo
     """

     OutputDatasetContext = OutputVolumeContext
@@ -220,6 +220,110 @@ class PostProcessingStitchingDumper(DumperBase):
         )


+class PostProcessingStitchingDumperWithCache(PostProcessingStitchingDumper):
+    """
+    PostProcessingStitchingDumper with an intermediate cache in order to speed up writing.
+    The cache is saved to disk when full or when closing the dumper.
+    Mostly convenient for HDF5
+    """
+
+    def __init__(self, configuration):
+        super().__init__(configuration)
+        self.__cache = None
+        """cache as a numpy.ndarray"""
+        self.__cache_size = None
+        """how many frames do we want to keep in memory before dumping to disk"""
+        self.__dump_axis = None
+        """axis along which we load / save the data. Different from the stitching axis"""
+        self.__final_volume_shape = None
+        self.__output_frame_index = 0
+        self.__cache_index = 0
+
+    def init_cache(self, dump_axis, size, dtype):
+        if dump_axis not in (0, 1, 2):
+            raise ValueError(f"axis should be in (0, 1, 2). Got {dump_axis}")
+
+        self.__dump_axis = dump_axis
+        self.__cache_size = size
+        self.__cache = numpy.empty(
+            self._get_cache_shape(),
+            dtype=dtype,
+        )
+
+    def reset_cache(self):
+        self.__cache_index = 0
+
+    def set_final_volume_shape(self, shape):
+        self.__final_volume_shape = shape
+
+    def _get_cache_shape(self):
+        assert self.__final_volume_shape is not None, "final volume shape should already be defined"
+        if self.__dump_axis == 0:
+            return (
+                self.__cache_size,
+                self.__final_volume_shape[1],
+                self.__final_volume_shape[2],
+            )
+        elif self.__dump_axis == 1:
+            return (
+                self.__final_volume_shape[0],
+                self.__cache_size,
+                self.__final_volume_shape[2],
+            )
+        elif self.__dump_axis == 2:
+            return (
+                self.__final_volume_shape[0],
+                self.__final_volume_shape[1],
+                self.__cache_size,
+            )
+        else:
+            raise RuntimeError("dump axis should be defined before using the cache")
+
+    def save_stitched_frame(
+        self,
+        stitched_frame: numpy.ndarray,
+        composition_cls: dict,
+        i_frame: int,
+        axis: int,
+    ):
+        """save the frame to the volume. In this use case save the frame to the buffer, waiting to be dumped later.
+        We expect 'save_stitched_frame' to be called with contiguous frames (in the output volume space)
+        """
+        index_cache = self.__cache_index
+        if self.__dump_axis == 0:
+            self.__cache[index_cache,] = stitched_frame
+        elif self.__dump_axis == 1:
+            self.__cache[:, index_cache, :] = stitched_frame
+        elif self.__dump_axis == 2:
+            self.__cache[:, :, index_cache] = stitched_frame
+        else:
+            raise RuntimeError("dump axis should be defined before using the cache")
+        self.__cache_index += 1
+
+    def dump_cache(self, nb_frames):
+        """
+        dump the first nb_frames to disk
+        """
+        output_dataset_start_index = self.__output_frame_index
+        output_dataset_end_index = self.__output_frame_index + nb_frames
+        if self.__dump_axis == 0:
+            self.output_dataset[output_dataset_start_index:output_dataset_end_index] = self.__cache[:nb_frames]
+        elif self.__dump_axis == 1:
+            self.output_dataset[
+                :,
+                output_dataset_start_index:output_dataset_end_index,
+            ] = self.__cache[:, :nb_frames]
+        elif self.__dump_axis == 2:
+            self.output_dataset[:, :, output_dataset_start_index:output_dataset_end_index] = self.__cache[
+                :, :, :nb_frames
+            ]
+        else:
+            raise RuntimeError("dump axis should be defined before using the cache")
+
+        self.__output_frame_index = output_dataset_end_index
+        self.reset_cache()
+
+
 class PostProcessingStitchingDumperNoDD(PostProcessingStitchingDumper):
     """
     same as PostProcessingStitchingDumper but prevents data duplication.
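Note: the point of the new class is to replace many small per-frame writes into a (possibly chunked) HDF5 dataset with one slab write every `cache_size` frames. A standalone sketch of the same init / save-frame / dump cycle, with a numpy array standing in for the output dataset:

    import numpy

    class FrameCache:
        # simplified stand-in for PostProcessingStitchingDumperWithCache,
        # hard-wired to dump_axis=1 as used in post_processing.py below
        def __init__(self, volume_shape, cache_size):
            self.volume = numpy.zeros(volume_shape, dtype="f")  # stands in for the HDF5 dataset
            self.cache = numpy.empty((volume_shape[0], cache_size, volume_shape[2]), dtype="f")
            self.cache_index = 0
            self.out_index = 0

        def save_frame(self, frame):
            # accumulate contiguous frames along axis 1
            self.cache[:, self.cache_index, :] = frame
            self.cache_index += 1

        def dump(self, nb_frames):
            # single slab write instead of nb_frames separate writes
            self.volume[:, self.out_index:self.out_index + nb_frames, :] = self.cache[:, :nb_frames]
            self.out_index += nb_frames
            self.cache_index = 0

    cache = FrameCache(volume_shape=(8, 100, 8), cache_size=50)
    for _ in range(50):
        cache.save_frame(numpy.ones((8, 8), dtype="f"))
    cache.dump(nb_frames=50)
    assert cache.volume[:, :50, :].min() == 1.0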
nabu/stitching/stitcher/post_processing.py
CHANGED
@@ -396,7 +396,7 @@ class PostProcessingStitching(SingleAxisStitcher):

         data_type = self.get_output_data_type()

-        if self.progress:
+        if self.progress is not None:
             self.progress.total = final_volume_shape[1]

         y_index = 0
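Note: `if self.progress:` is subtly wrong for progress-bar objects: tqdm, for instance, derives its truth value from `total`, so an existing bar with `total == 0` is falsy and would never be updated. The explicit `is not None` check avoids this; illustrated with a minimal stand-in:

    class ProgressLike:
        # minimal stand-in for a tqdm-like bar whose truthiness follows 'total'
        def __init__(self, total=0):
            self.total = total

        def __bool__(self):
            return self.total > 0

    progress = ProgressLike(total=0)

    if progress:  # never taken: the bar exists but is falsy
        raise AssertionError("unreachable")
    if progress is not None:  # taken: the check the diff switches to
        progress.total = 42
    assert progress.total == 42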
@@ -411,7 +411,7 @@ class PostProcessingStitching(SingleAxisStitcher):
             "dtype": data_type,
             "dumper": self.dumper,
         }
-        from .dumper.postprocessing import PostProcessingStitchingDumperNoDD
+        from .dumper.postprocessing import PostProcessingStitchingDumperNoDD, PostProcessingStitchingDumperWithCache

         # TODO: FIXME: for now not very elegant but in the case of avoiding data duplication
         # we need to provide the information about the stitched part shape.
@@ -420,7 +420,11 @@ class PostProcessingStitching(SingleAxisStitcher):
             output_dataset_args["stitching_sources_arr_shapes"] = tuple(
                 [(abs(overlap), n_slices, self._stitching_constant_length) for overlap in self._axis_0_rel_final_shifts]
             )
+        elif isinstance(self.dumper, PostProcessingStitchingDumperWithCache):
+            self.dumper.set_final_volume_shape(final_volume_shape)

+        bunch_size = 50
+        # how many frames do we stitch between two reads from / saves to disk
         with self.dumper.OutputDatasetContext(**output_dataset_args):  # noqa: SIM117
             # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
             with _RawDatasetsContext(
@@ -430,11 +434,14 @@ class PostProcessingStitching(SingleAxisStitcher):
                 # note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume)
                 # to speed up we read by bunch of dataset. For numpy array this doesn't change anything
                 # but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chunks)
-                # note: we read
+                # note: we read through axis 1
                 if isinstance(self.dumper, PostProcessingStitchingDumperNoDD):
                     self.dumper.raw_regions_hdf5_dataset = raw_datasets
+                if isinstance(self.dumper, PostProcessingStitchingDumperWithCache):
+                    self.dumper.init_cache(dump_axis=1, dtype=data_type, size=bunch_size)
+
                 for bunch_start, bunch_end in PostProcessingStitching._data_bunch_iterator(
-                    slices=self._slices_to_stitch, bunch_size=
+                    slices=self._slices_to_stitch, bunch_size=bunch_size
                 ):
                     for data_frames in PostProcessingStitching._get_bunch_of_data(
                         bunch_start,
@@ -469,6 +476,9 @@ class PostProcessingStitching(SingleAxisStitcher):
                         self.progress.update()
                     y_index += 1

+            if isinstance(self.dumper, PostProcessingStitchingDumperWithCache):
+                self.dumper.dump_cache(nb_frames=(bunch_end - bunch_start))
+
     # alias to general API
     def _create_stitching(self, store_composition):
         self._create_stitched_volume(store_composition=store_composition)
nabu/stitching/stitcher/pre_processing.py
CHANGED
@@ -677,7 +677,7 @@ class PreProcessingStitching(SingleAxisStitcher):
         scans_projections_indexes = []
         for scan, reverse in zip(self.series, self.reading_orders):
             scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1)))
-        if self.progress:
+        if self.progress is not None:
             self.progress.total = self.get_n_slices_to_stitch()

         if isinstance(self._slices_to_stitch, slice):
nabu/stitching/stitcher/single_axis.py
CHANGED
@@ -37,7 +37,7 @@ class _SingleAxisMetaClass(type):
     Metaclass for single axis stitcher in order to aggregate dumper class and axis
     """

-    def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
+    def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
         mcls = super().__new__(mcls, name, bases, attrs)
         mcls._axis = axis
         mcls._dumperCls = dumper_cls
@@ -470,12 +470,13 @@ class SingleAxisStitcher(_StitcherBase, metaclass=_SingleAxisMetaClass):
             pad_mode=pad_mode,
             new_unstitched_axis_size=new_width,
         )
-        dumper
-
-
-
-
-
+        if dumper is not None:
+            dumper.save_stitched_frame(
+                stitched_frame=stitched_frame,
+                composition_cls=composition_cls,
+                i_frame=i_frame,
+                axis=1,
+            )

         if return_composition_cls:
             return stitched_frame, composition_cls
nabu/stitching/stitcher/z_stitcher.py
CHANGED
@@ -1,6 +1,10 @@
 from nabu.stitching.stitcher.pre_processing import PreProcessingStitching
 from nabu.stitching.stitcher.post_processing import PostProcessingStitching
-from .dumper import
+from .dumper import (
+    PreProcessingStitchingDumper,
+    PostProcessingStitchingDumperNoDD,
+    PostProcessingStitchingDumperWithCache,
+)
 from nabu.stitching.stitcher.single_axis import _SingleAxisMetaClass

@@ -26,12 +30,12 @@ class PreProcessingZStitcher(
 class PostProcessingZStitcher(
     PostProcessingStitching,
     metaclass=_SingleAxisMetaClass,
-    dumper_cls=
+    dumper_cls=PostProcessingStitchingDumperWithCache,
     axis=0,
 ):
     @property
     def serie_label(self) -> str:
-        return "z-
+        return "z-series"


 class PostProcessingZStitcherNoDD(
@@ -42,4 +46,4 @@ class PostProcessingZStitcherNoDD(
 ):
     @property
     def serie_label(self) -> str:
-        return "z-
+        return "z-series"
nabu/stitching/utils/utils.py
CHANGED
@@ -1,4 +1,4 @@
-from
+from packaging.version import parse as parse_version
 from typing import Optional, Union
 import logging
 import functools
@@ -521,7 +521,7 @@ def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple:
         _logger.warning("itk is not installed. Please install it to find shift with it")
         return (0, 0)

-    if
+    if parse_version(itk.Version.GetITKVersion()) < parse_version("4.9.0"):
         _logger.error("ITK 4.9.0 is required to find shift with it.")
         return (0, 0)

nabu/testutils.py
CHANGED
@@ -14,14 +14,14 @@ __big_testdata_dir__ = os.environ.get("NABU_BIGDATA_DIR")
 if __big_testdata_dir__ is None or not (os.path.isdir(__big_testdata_dir__)):
     __big_testdata_dir__ = None

-__do_long_tests__ = os.environ.get("NABU_LONG_TESTS", False)
+__do_long_tests__ = os.environ.get("NABU_LONG_TESTS", False)  # noqa: PLW1508
 if __do_long_tests__:
     try:
         __do_long_tests__ = bool(int(__do_long_tests__))
     except:
         __do_long_tests__ = False

-__do_large_mem_tests__ = os.environ.get("NABU_LARGE_MEM_TESTS", False)
+__do_large_mem_tests__ = os.environ.get("NABU_LARGE_MEM_TESTS", False)  # noqa: PLW1508
 if __do_large_mem_tests__:
     try:
         __do_large_mem_tests__ = bool(int(__do_large_mem_tests__))
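Note: the `# noqa: PLW1508` comments silence Ruff's "invalid envvar default" warning: `os.environ.get` returns strings, so a `False` default is type-inconsistent, but here it deliberately short-circuits the normalization that follows. The pattern in isolation (with an illustrative variable name):

    import os

    # A missing variable yields False and skips the block entirely;
    # otherwise the string value ("0"/"1") is normalized to a bool.
    flag = os.environ.get("SOME_TEST_FLAG", False)  # noqa: PLW1508
    if flag:
        try:
            flag = bool(int(flag))
        except ValueError:
            flag = False
    assert flag in (True, False)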
nabu/utils.py
CHANGED
@@ -180,6 +180,8 @@ def list_match_queries(available, queries):
     Given a list of strings, return all items matching any one of the elements of "queries"
     """
     matches = []
+    if isinstance(queries, str):
+        queries = [queries]
     for a in available:
         for q in queries:
             if fnmatch(a, q):
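Note: this guard is what makes the new `test_list_match_queries` assertions pass with a bare string query. The function, reproduced standalone with the fix, together with the matching test cases:

    from fnmatch import fnmatch

    def list_match_queries(available, queries):
        """Return all items of 'available' matching at least one pattern of 'queries'."""
        matches = []
        if isinstance(queries, str):  # the fix: accept a single string query
            queries = [queries]
        for a in available:
            for q in queries:
                if fnmatch(a, q):
                    matches.append(a)
        return matches

    avail = ["entry%04d" % i for i in range(100)]
    assert list_match_queries(avail, "entry0000") == ["entry0000"]
    assert list_match_queries(avail, ["entry000?"]) == avail[:10]
    assert list_match_queries(avail, ["entry*"]) == avail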
@@ -711,9 +713,9 @@ def concatenate_dict(dict_1, dict_2) -> dict:
     return res


-class BaseClassError:
+class BaseClassError(BaseException):
     def __init__(self, *args, **kwargs):
-        raise
+        raise NotImplementedError("Base class")


 def MissingComponentError(msg):
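Note: deriving from `BaseException` makes `raise BaseClassError` statements legal (Python requires raised objects to derive from BaseException); instantiation then raises `NotImplementedError` from `__init__`. This is how the new `scan_basename` / `scan_dirname` properties in `dataset_analyzer.py` mark themselves as abstract; a reduced sketch with hypothetical class names:

    class BaseClassError(BaseException):
        def __init__(self, *args, **kwargs):
            raise NotImplementedError("Base class")

    class DatasetAnalyzerSketch:
        # reduced stand-in for the base class in dataset_analyzer.py
        @property
        def scan_dirname(self):
            raise BaseClassError  # instantiating it raises NotImplementedError

    class HDF5Sketch(DatasetAnalyzerSketch):
        @property
        def scan_dirname(self):
            return "/path/to/scan"

    try:
        DatasetAnalyzerSketch().scan_dirname
    except NotImplementedError:
        pass  # the base class refuses to answer
    assert HDF5Sketch().scan_dirname == "/path/to/scan"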
@@ -796,6 +798,11 @@ def median2(img):
 # ---------------------------- Decorators --------------------------------------
 # ------------------------------------------------------------------------------

+
+def no_decorator(func):
+    return func
+
+
 _warnings = {}

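Note: `no_decorator` is an identity decorator, the usual trick for applying a decorator conditionally. A hypothetical usage (the switch and the decorated function are illustrative, not from nabu):

    from functools import lru_cache

    def no_decorator(func):
        return func

    use_cache = False  # hypothetical feature switch

    # pick the real decorator or the identity one at import time
    maybe_cached = lru_cache(maxsize=None) if use_cache else no_decorator

    @maybe_cached
    def heavy_computation(x):
        return x**2

    assert heavy_computation(3) == 9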