nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,22 @@
1
1
  import os
2
2
  import copy
3
+ from typing import Optional, Union
3
4
  import numpy
4
5
  from silx.io.url import DataUrl
5
- from typing import Iterable, Optional, Union
6
6
  from tomoscan.tomoobject import TomoObject
7
- from nabu.app.bootstrap_stitching import _SECTIONS_COMMENTS
8
- from nabu.stitching.config import (
7
+ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
8
+ from tomoscan.esrf import EDFTomoScan
9
+ from tomoscan.esrf.volume import HDF5Volume, MultiTIFFVolume
10
+ from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase
11
+ from ..app.bootstrap_stitching import _SECTIONS_COMMENTS
12
+ from ..pipeline.config import generate_nabu_configfile
13
+ from .config import (
9
14
  StitchingConfiguration,
10
15
  get_default_stitching_config,
11
16
  PreProcessedZStitchingConfiguration,
12
17
  PostProcessedZStitchingConfiguration,
13
18
  SLURM_SECTION,
14
19
  )
15
- from nabu.pipeline.config import generate_nabu_configfile
16
- from tomoscan.esrf import HDF5TomoScan
17
- from tomoscan.esrf import EDFTomoScan
18
- from tomoscan.esrf.volume import HDF5Volume, MultiTIFFVolume
19
- from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase
20
20
 
21
21
  try:
22
22
  from sluurp.job import SBatchScriptJob
@@ -87,7 +87,7 @@ def split_stitching_configuration_to_slurm_job(
87
87
  original_output_file_path,
88
88
  os.path.basename(original_output_file_path) + f"_part_{i_sub_part}" + file_extension,
89
89
  )
90
- output_obj = HDF5TomoScan(
90
+ output_obj = NXtomoScan(
91
91
  scan=sub_configuration.output_file_path,
92
92
  entry=sub_configuration.output_data_path,
93
93
  )
@@ -195,7 +195,7 @@ def get_working_directory(obj: TomoObject) -> Optional[str]:
195
195
  return None
196
196
  elif isinstance(obj, EDFTomoScan):
197
197
  return obj.path
198
- elif isinstance(obj, HDF5TomoScan):
198
+ elif isinstance(obj, NXtomoScan):
199
199
  if obj.master_file is None:
200
200
  return None
201
201
  else:
@@ -0,0 +1,99 @@
1
+ import numpy
2
+ import pytest
3
+ from nabu.stitching.alignment import align_horizontally, PaddedRawData
4
+ from nabu.testutils import get_data
5
+
6
+
7
+ def test_alignment_axis_2():
8
+ """
9
+ test 'align_horizontally' function
10
+ """
11
+ dataset = get_data("chelsea.npz")["data"] # shape is (300, 451)
12
+
13
+ # test if new_width < current_width: should raise an error
14
+ with pytest.raises(ValueError):
15
+ align_horizontally(dataset, alignment="center", new_width=10)
16
+
17
+ # test some use cases
18
+ res = align_horizontally(
19
+ dataset,
20
+ alignment="center",
21
+ new_width=600,
22
+ pad_mode="mean",
23
+ )
24
+ assert res.shape == (300, 600)
25
+ numpy.testing.assert_array_almost_equal(res[:, 74:-75], dataset)
26
+
27
+ res = align_horizontally(
28
+ dataset,
29
+ alignment="left",
30
+ new_width=600,
31
+ pad_mode="median",
32
+ )
33
+ assert res.shape == (300, 600)
34
+ numpy.testing.assert_array_almost_equal(res[:, :451], dataset)
35
+
36
+ res = align_horizontally(
37
+ dataset,
38
+ alignment="right",
39
+ new_width=600,
40
+ pad_mode="reflect",
41
+ )
42
+ assert res.shape == (300, 600)
43
+ numpy.testing.assert_array_almost_equal(res[:, -451:], dataset)
44
+
45
+
46
+ def test_PaddedRawData():
47
+ """
48
+ test PaddedVolume class
49
+ """
50
+ data = numpy.linspace(
51
+ start=0,
52
+ stop=20 * 6 * 3,
53
+ dtype=numpy.int64,
54
+ num=20 * 6 * 3,
55
+ )
56
+ data = data.reshape((3, 6, 20))
57
+
58
+ padded_volume = PaddedRawData(data=data, axis_1_pad_width=(4, 1))
59
+
60
+ assert padded_volume.shape == (3, 6 + 4 + 1, 20)
61
+
62
+ numpy.testing.assert_array_equal(
63
+ padded_volume[:, 0, :],
64
+ numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64),
65
+ )
66
+ numpy.testing.assert_array_equal(
67
+ padded_volume[:, 3, :],
68
+ numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64),
69
+ )
70
+ numpy.testing.assert_array_equal(
71
+ padded_volume[:, 10, :],
72
+ numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64),
73
+ )
74
+ assert padded_volume[:, 3, :].shape == (3, 1, 20)
75
+ numpy.testing.assert_array_equal(
76
+ padded_volume[:, 4, :],
77
+ data[:, 0:1, :], # TODO: have a look, return a 3D array when a 2D expected...
78
+ )
79
+
80
+ with pytest.raises(ValueError):
81
+ padded_volume[:, 40, :]
82
+ with pytest.raises(ValueError):
83
+ padded_volume[:, 5:1, :]
84
+
85
+ arrays = (
86
+ numpy.zeros(shape=(3, 4, 20), dtype=numpy.int64),
87
+ data,
88
+ numpy.zeros(shape=(3, 1, 20), dtype=numpy.int64),
89
+ )
90
+ expected_volume = numpy.hstack(
91
+ arrays,
92
+ )
93
+ assert padded_volume[:, :, :].shape == padded_volume.shape
94
+ assert expected_volume.shape == padded_volume.shape
95
+
96
+ numpy.testing.assert_array_equal(
97
+ padded_volume[:, :, :],
98
+ expected_volume,
99
+ )
@@ -113,7 +113,14 @@ def test_stitching_config(stitching_type, option_level):
113
113
  @pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1"))
114
114
  @pytest.mark.parametrize(
115
115
  "axis_shifts",
116
- ("", None, "None", "", "skimage", "nabu-fft", "shift-grid"),
116
+ (
117
+ "",
118
+ None,
119
+ "None",
120
+ "",
121
+ "skimage",
122
+ "nabu-fft",
123
+ ),
117
124
  )
118
125
  @pytest.mark.parametrize("axis_shifts_params", ("", {}, "window_size=200"))
119
126
  @pytest.mark.parametrize(
@@ -135,6 +142,7 @@ def test_stitching_config(stitching_type, option_level):
135
142
  "slurm_config",
136
143
  (
137
144
  {
145
+ stiching_config.SLURM_MODULES_TO_LOADS: "tomotools",
138
146
  stiching_config.SLURM_PREPROCESSING_COMMAND: "",
139
147
  stiching_config.SLURM_CLEAN_SCRIPTS: True,
140
148
  stiching_config.SLURM_MEM: 56,
@@ -181,6 +189,13 @@ def test_PreProcessedZStitchingConfiguration(
181
189
  stiching_config.NEXUS_VERSION_FIELD: None,
182
190
  },
183
191
  stiching_config.SLURM_SECTION: slurm_config,
192
+ stiching_config.NORMALIZATION_BY_SAMPLE_SECTION: {
193
+ stiching_config.NORMALIZATION_BY_SAMPLE_MARGIN: 1,
194
+ stiching_config.NORMALIZATION_BY_SAMPLE_SIDE: "right",
195
+ stiching_config.NORMALIZATION_BY_SAMPLE_ACTIVE_FIELD: True,
196
+ stiching_config.NORMALIZATION_BY_SAMPLE_METHOD: "mean",
197
+ stiching_config.NORMALIZATION_BY_SAMPLE_WIDTH: 31,
198
+ },
184
199
  },
185
200
  )
186
201
 
@@ -1,6 +1,8 @@
1
- from nabu.stitching.overlap import compute_image_minimum_divergence, compute_image_higher_signal
2
- from nabu.testutils import get_data
3
1
  import numpy
2
+ import pytest
3
+
4
+ from nabu.stitching.overlap import compute_image_minimum_divergence, compute_image_higher_signal, check_overlaps
5
+ from nabu.testutils import get_data
4
6
 
5
7
 
6
8
  def test_compute_image_minimum_divergence():
@@ -30,3 +32,67 @@ def test_compute_image_higher_signal():
30
32
  stitching,
31
33
  raw_data,
32
34
  )
35
+
36
+
37
+ def test_check_overlaps():
38
+ """test 'check_overlaps' function"""
39
+
40
+ # two frames, ordered and with an overlap
41
+ check_overlaps(
42
+ frames=(
43
+ numpy.ones(10),
44
+ numpy.ones(20),
45
+ ),
46
+ positions=((10,), (0,)),
47
+ axis=0,
48
+ raise_error=True,
49
+ )
50
+
51
+ # two frames, ordered and without an overlap
52
+ with pytest.raises(ValueError):
53
+ check_overlaps(
54
+ frames=(
55
+ numpy.ones(10),
56
+ numpy.ones(20),
57
+ ),
58
+ positions=((0,), (100,)),
59
+ axis=0,
60
+ raise_error=True,
61
+ )
62
+
63
+ # two frames, frame 0 fully overlap frame 1
64
+ with pytest.raises(ValueError):
65
+ check_overlaps(
66
+ frames=(
67
+ numpy.ones(20),
68
+ numpy.ones(10),
69
+ ),
70
+ positions=((8,), (5,)),
71
+ axis=0,
72
+ raise_error=True,
73
+ )
74
+
75
+ # three frames 'overlaping' as expected
76
+ check_overlaps(
77
+ frames=(
78
+ numpy.ones(10),
79
+ numpy.ones(20),
80
+ numpy.ones(10),
81
+ ),
82
+ positions=((20,), (10,), (0,)),
83
+ axis=0,
84
+ raise_error=True,
85
+ )
86
+
87
+ # three frames: frame 0 overlap frame 1 but also frame 2
88
+ with pytest.raises(ValueError):
89
+ check_overlaps(
90
+ frames=(
91
+ numpy.ones(20),
92
+ numpy.ones(10),
93
+ numpy.ones(10),
94
+ ),
95
+ positions=((20,), (15,), (11,)),
96
+ axis=0,
97
+ raise_error=True,
98
+ )
@@ -0,0 +1,49 @@
1
+ import numpy
2
+ import pytest
3
+ from nabu.stitching.sample_normalization import normalize_frame, SampleSide, Method
4
+
5
+
6
+ def test_normalize_frame():
7
+ """
8
+ test normalize_frame function
9
+ """
10
+ with pytest.raises(TypeError):
11
+ normalize_frame("toto", "left", "median")
12
+ with pytest.raises(TypeError):
13
+ normalize_frame(numpy.linspace(0, 100), "left", "median")
14
+
15
+ frame = numpy.ones((10, 40))
16
+ frame[:, 15:25] = numpy.arange(1, 101, step=1).reshape((10, 10))
17
+
18
+ numpy.testing.assert_array_equal(
19
+ normalize_frame(
20
+ frame=frame,
21
+ side="left",
22
+ method="mean",
23
+ sample_width=10,
24
+ margin_before_sample=2,
25
+ )[:, 15:25],
26
+ numpy.arange(0, 100, step=1).reshape((10, 10)),
27
+ )
28
+
29
+ numpy.testing.assert_array_equal(
30
+ normalize_frame(
31
+ frame=frame,
32
+ side="right",
33
+ method="median",
34
+ sample_width=10,
35
+ margin_before_sample=2,
36
+ )[:, 15:25],
37
+ numpy.arange(0, 100, step=1).reshape((10, 10)),
38
+ )
39
+
40
+ assert not numpy.array_equal(
41
+ normalize_frame(
42
+ frame=frame,
43
+ side="right",
44
+ method="mean",
45
+ sample_width=10,
46
+ margin_before_sample=20,
47
+ )[:, 15:25],
48
+ numpy.arange(0, 100, step=1).reshape((10, 10)),
49
+ )
@@ -1,7 +1,7 @@
1
1
  import os
2
2
  import numpy
3
3
  import pytest
4
- from tomoscan.esrf import HDF5TomoScan
4
+ from tomoscan.esrf import NXtomoScan
5
5
  from tomoscan.esrf.volume import HDF5Volume
6
6
  from tomoscan.esrf.scan.utils import cwd_context
7
7
  from nabu.stitching.config import PreProcessedZStitchingConfiguration, SlurmConfig
@@ -11,7 +11,7 @@ from nabu.stitching.slurm_utils import (
11
11
  get_working_directory,
12
12
  split_stitching_configuration_to_slurm_job,
13
13
  )
14
- from tomoscan.esrf.mock import MockHDF5
14
+ from tomoscan.esrf.mock import MockNXtomo
15
15
 
16
16
  try:
17
17
  import sluurp
@@ -59,7 +59,7 @@ def test_split_slices():
59
59
 
60
60
  def test_get_working_directory():
61
61
  """test get_working_directory function"""
62
- assert get_working_directory(HDF5TomoScan("/this/is/my/hdf5file.hdf5", "entry")) == "/this/is/my"
62
+ assert get_working_directory(NXtomoScan("/this/is/my/hdf5file.hdf5", "entry")) == "/this/is/my"
63
63
  assert get_working_directory(HDF5Volume("/this/is/my/volume.hdf5", "entry")) == "/this/is/my"
64
64
 
65
65
 
@@ -85,14 +85,14 @@ def test_split_stitching_configuration_to_slurm_job(tmp_path):
85
85
  with cwd_context(inputs_dir):
86
86
  # the current working directory context help to check file path are moved to absolute.
87
87
  # which is important because those jobs will be launched on slurm
88
- scan_1 = MockHDF5(
88
+ scan_1 = MockNXtomo(
89
89
  os.path.join("scan_1"),
90
90
  n_proj=10,
91
91
  n_ini_proj=10,
92
92
  dim=100,
93
93
  ).scan
94
94
 
95
- scan_2 = MockHDF5(
95
+ scan_2 = MockNXtomo(
96
96
  os.path.join("scan_2"),
97
97
  n_proj=10,
98
98
  n_ini_proj=10,
@@ -1,45 +1,15 @@
1
- from nabu.stitching.utils import shift_grid_search, ScoreMethod, has_itk, find_shift_with_itk
2
- import scipy.misc
3
- from scipy.ndimage import shift as scipy_shift
1
+ from nabu.stitching.utils import has_itk, find_shift_with_itk
4
2
  from scipy.ndimage import shift as shift_scipy
5
3
  import numpy
6
4
  import pytest
7
-
8
-
9
- @pytest.mark.parametrize("shift", [(0, 3), (-4, 0), (8, -6)])
10
- def test_shift_grid_search(shift):
11
- """
12
- test shift_grid_search algorithm
13
- """
14
- y = numpy.sin(numpy.linspace(0, numpy.pi, 250))
15
-
16
- weights = numpy.cos(numpy.linspace(-numpy.pi, numpy.pi, 120))
17
- image = numpy.outer(y, weights)
18
- # add a simple line with different value to ease detection
19
- image[150] = 1.0
20
- image[:, 50] = 1.2
21
-
22
- image_ref = image
23
- image_with_shift = scipy_shift(image_ref.copy(), shift=-numpy.array(shift))
24
- score_method = ScoreMethod.TV
25
-
26
- best_shift = shift_grid_search(
27
- image_ref,
28
- image_with_shift,
29
- window_sizes=(40, 20),
30
- axis=(0, 1),
31
- step_size=1,
32
- score_method=score_method,
33
- )
34
-
35
- assert tuple(best_shift) == shift
5
+ from nabu.testutils import get_data
36
6
 
37
7
 
38
8
  @pytest.mark.parametrize("data_type", (numpy.float32, numpy.uint16))
39
9
  @pytest.mark.skipif(not has_itk, reason="itk not installed")
40
10
  def test_find_shift_with_itk(data_type):
41
11
  shift = (5, 2)
42
- img1 = scipy.misc.ascent().astype(data_type)
12
+ img1 = get_data("chelsea.npz")["data"].astype(data_type)
43
13
  img2 = shift_scipy(
44
14
  img1.copy(),
45
15
  shift=shift,