nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/testutils.py CHANGED
@@ -59,6 +59,24 @@ def get_data(*dataset_path):
59
59
  return np.load(dataset_downloaded_path)
60
60
 
61
61
 
62
def get_array_of_given_shape(img, shape, dtype):
    """
    From a given image, returns an array of the wanted shape and dtype.

    The two-dimensional image `img` is tiled (repeated) until it is large enough,
    then cropped to the wanted shape.

    Parameters
    ----------
    img: numpy.ndarray
        Two-dimensional source image.
    shape: tuple of int
        Wanted shape, with 1, 2 or 3 dimensions.
          - 1D: the first column of the (tiled) image is used.
          - 2D: the (tiled) image is cropped.
          - 3D: the (tiled) image is stacked shape[0] times, then cropped.
    dtype: numpy dtype
        Wanted data type.

    Returns
    -------
    arr: numpy.ndarray
        C-contiguous array of type `dtype`. Note that singleton dimensions are removed
        (numpy.squeeze), so for example shape (1, N) yields an array of shape (N,).
    """
    # The image only provides the last two dimensions of the result: for a 3D shape,
    # dimension 0 comes from stacking. So the image must cover shape[-2:] (which is
    # the whole shape for 1D/2D requests), not the first two entries of `shape`.
    target_shape = tuple(shape[-2:])
    while any([i_dim < s_dim for i_dim, s_dim in zip(img.shape, target_shape)]):
        img = np.tile(img, (2, 2))

    if len(shape) == 1:
        arr = img[: shape[0], 0]
    elif len(shape) == 2:
        arr = img[: shape[0], : shape[1]]
    else:
        arr = np.tile(img, (shape[0], 1, 1))[: shape[0], : shape[1], : shape[2]]
    return np.ascontiguousarray(np.squeeze(arr), dtype=dtype)
78
+
79
+
62
80
  def get_big_data(filename):
63
81
  if __big_testdata_dir__ is None:
64
82
  return None
@@ -0,0 +1,586 @@
1
+ # pylint: skip-file
2
+
3
+ """
4
+ This file is a "GPU" (through cupy) implementation of "remove_all_stripe".
5
+ The original method is implemented by Nghia Vo in the algotom project: https://github.com/algotom/algotom/blob/master/algotom/prep/removal.py
6
+ The implementation using cupy is done by Viktor Nikitin in the tomocupy project: https://github.com/tomography/tomocupy/blame/main/src/tomocupy/remove_stripe.py
7
+ License follows.
8
+
9
+ For now we can't rely on off-the-shelf tomocupy as it's not packaged in pypi, and compilation is quite tedious.
10
+ """
11
+
12
+ # *************************************************************************** #
13
+ # Copyright © 2022, UChicago Argonne, LLC #
14
+ # All Rights Reserved #
15
+ # Software Name: Tomocupy #
16
+ # By: Argonne National Laboratory #
17
+ # #
18
+ # OPEN SOURCE LICENSE #
19
+ # #
20
+ # Redistribution and use in source and binary forms, with or without #
21
+ # modification, are permitted provided that the following conditions are met: #
22
+ # #
23
+ # 1. Redistributions of source code must retain the above copyright notice, #
24
+ # this list of conditions and the following disclaimer. #
25
+ # 2. Redistributions in binary form must reproduce the above copyright #
26
+ # notice, this list of conditions and the following disclaimer in the #
27
+ # documentation and/or other materials provided with the distribution. #
28
+ # 3. Neither the name of the copyright holder nor the names of its #
29
+ # contributors may be used to endorse or promote products derived #
30
+ # from this software without specific prior written permission. #
31
+ # #
32
+ # #
33
+ # *************************************************************************** #
34
+ # DISCLAIMER #
35
+ # #
36
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS #
37
+ # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
38
+ # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
39
+ # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT #
40
+ # HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, #
41
+ # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED #
42
+ # TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR #
43
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF #
44
+ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING #
45
+ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS #
46
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
47
+ # *************************************************************************** #
48
+
49
+ try:
50
+ import cupy as cp
51
+ import pywt
52
+ from cupyx.scipy.ndimage import median_filter
53
+ from cupyx.scipy import signal
54
+ from cupyx.scipy.ndimage import binary_dilation
55
+ from cupyx.scipy.ndimage import uniform_filter1d
56
+ __have_tomocupy_deringer__ = True
57
+ except ImportError as err:
58
+ __have_tomocupy_deringer__ = False
59
+ __tomocupy_deringer_import_error__ = err
60
+
61
+
62
+ ###### Ring removal with wavelet filtering (adapted for cupy from the pytorch_wavelets package https://pytorch-wavelets.readthedocs.io/)################################################################################
63
+
64
+ def _reflect(x, minx, maxx):
65
+ """Reflect the values in matrix *x* about the scalar values *minx* and
66
+ *maxx*. Hence a vector *x* containing a long linearly increasing series is
67
+ converted into a waveform which ramps linearly up and down between *minx*
68
+ and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
69
+ 0.5), the ramps will have repeated max and min samples.
70
+
71
+ .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
72
+ .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
73
+
74
+ """
75
+ x = cp.asanyarray(x)
76
+ rng = maxx - minx
77
+ rng_by_2 = 2 * rng
78
+ mod = cp.fmod(x - minx, rng_by_2)
79
+ normed_mod = cp.where(mod < 0, mod + rng_by_2, mod)
80
+ out = cp.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
81
+ return cp.array(out, dtype=x.dtype)
82
+
83
+
84
def _mypad(x, pad, value=0):
    """
    Symmetric padding of the last two axes of a 4D array, using reflected indices.

    Parameters
    ----------
    x: array
        Array to pad. Indexing assumes a 4D array (N, C, H, W).
    pad: tuple of int
        (left, right, top, bottom) pad sizes. left/right apply to the last axis (W),
        top/bottom apply to the second-to-last axis (H).
    value:
        Unused, kept for interface compatibility.

    Returns
    -------
    The padded array. Vertical and horizontal padding can now be requested together;
    they are applied one after the other (previously this case returned None).
    """
    left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
    pad_rows = top != 0 or bottom != 0
    pad_cols = left != 0 or right != 0

    # When no padding at all is requested, the row gather below uses identity indices,
    # i.e. a copy of x is returned.
    if pad_rows or not pad_cols:
        n_rows = x.shape[-2]
        row_idx = _reflect(cp.arange(-top, n_rows + bottom, dtype='int32'), -0.5, n_rows - 0.5)
        x = x[:, :, row_idx]
    if pad_cols:
        n_cols = x.shape[-1]
        col_idx = _reflect(cp.arange(-left, n_cols + right, dtype='int32'), -0.5, n_cols - 0.5)
        x = x[:, :, :, col_idx]
    return x
104
+
105
+
106
+ def _conv2d(x, w, stride, pad, groups=1):
107
+ """ Convolution (equivalent pytorch.conv2d)
108
+ """
109
+ if pad != 0:
110
+ x = cp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
111
+
112
+ b, ci, hi, wi = x.shape
113
+ co, _, hk, wk = w.shape
114
+ ho = int(cp.floor(1 + (hi - hk) / stride[0]))
115
+ wo = int(cp.floor(1 + (wi - wk) / stride[1]))
116
+ out = cp.zeros([b, co, ho, wo], dtype='float32')
117
+ x = cp.expand_dims(x, axis=1)
118
+ w = cp.expand_dims(w, axis=0)
119
+ chunk = ci//groups
120
+ chunko = co//groups
121
+ for g in range(groups):
122
+ for ii in range(hk):
123
+ for jj in range(wk):
124
+ x_windows = x[:, :, g*chunk:(g+1)*chunk, ii:ho *
125
+ stride[0]+ii:stride[0], jj:wo*stride[1]+jj:stride[1]]
126
+ out[:, g*chunko:(g+1)*chunko] += cp.sum(x_windows *
127
+ w[:, g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1], axis=2)
128
+ return out
129
+
130
+
131
+ def _conv_transpose2d(x, w, stride, pad, bias=None, groups=1):
132
+ """ Transposed convolution (equivalent pytorch.conv_transpose2d)
133
+ """
134
+ b, co, ho, wo = x.shape
135
+ co, ci, hk, wk = w.shape
136
+
137
+ hi = (ho-1)*stride[0]+hk
138
+ wi = (wo-1)*stride[1]+wk
139
+ out = cp.zeros([b, ci, hi, wi], dtype='float32')
140
+ chunk = ci//groups
141
+ chunko = co//groups
142
+ for g in range(groups):
143
+ for ii in range(hk):
144
+ for jj in range(wk):
145
+ x_windows = x[:, g*chunko:(g+1)*chunko]
146
+ out[:, g*chunk:(g+1)*chunk, ii:ho*stride[0]+ii:stride[0], jj:wo*stride[1] +
147
+ jj:stride[1]] += x_windows * w[g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1]
148
+ if pad != 0:
149
+ out = out[:, :, pad[0]:out.shape[2]-pad[0], pad[1]:out.shape[3]-pad[1]]
150
+ return out
151
+
152
+
153
def afb1d(x, h0, h1='zero', dim=-1):
    """
    One-level 1D analysis filter bank, applied along a single spatial axis of an image batch.

    Parameters
    ----------
    x (array): 4D input whose last two dimensions are the spatial ones
    h0 (array): lowpass filter, of shape (1, 1, h, 1) or (1, 1, 1, w)
    h1 (array): highpass filter, of shape (1, 1, h, 1) or (1, 1, 1, w)
    dim (int): axis of filtering. 2 (or -2) is a vertical filter ("column filtering",
        across the rows); 3 (or -1) is a horizontal filter ("row filtering", across the columns).

    Returns
    -------
    lohi: lowpass and highpass subbands, concatenated along the channel dimension
        (for each input channel: lowpass then highpass), downsampled by 2 along `dim`.
    """
    n_channels = x.shape[1]
    axis = dim % 4  # positive axis index
    vertical = axis == 2
    stride = (2, 1) if vertical else (1, 2)
    n_in = x.shape[axis]
    filt_len = h0.size

    filt_shape = [1, 1, 1, 1]
    filt_shape[axis] = filt_len
    # [h0, h1, h0, h1, ...]: each group of the grouped convolution yields (lowpass, highpass)
    filters = cp.concatenate([h0.reshape(*filt_shape), h1.reshape(*filt_shape)] * n_channels, axis=0)

    # Pad so that the output has as many samples as pywt's "symmetric" mode would give
    n_out = pywt.dwt_coeff_len(n_in, filt_len, mode='symmetric')
    total_pad = 2 * (n_out - 1) - n_in + filt_len
    before, after = total_pad // 2, (total_pad + 1) // 2
    pad = (0, 0, before, after) if vertical else (before, after, 0, 0)

    padded = _mypad(x, pad=pad)
    return _conv2d(padded, filters, stride=stride, pad=0, groups=n_channels)
190
+
191
+
192
def sfb1d(lo, hi, g0, g1='zero', dim=-1):
    """
    One-level 1D synthesis filter bank of an image array.

    The lowpass (`lo`) and highpass (`hi`) subbands are upsampled by 2 along `dim` with
    transposed convolutions by the synthesis filters `g0` / `g1`, and summed.
    dim=2 (or -2) is the vertical axis, dim=3 (or -1) the horizontal one.
    """
    n_channels = lo.shape[1]
    axis = dim % 4
    filt_len = g0.size
    filt_shape = [1, 1, 1, 1]
    filt_shape[axis] = filt_len

    if axis == 2:
        stride, pad = (2, 1), (filt_len - 2, 0)
    else:
        stride, pad = (1, 2), (0, filt_len - 2)

    # one copy of each filter per channel (depthwise transposed convolution)
    g0_rep = cp.concatenate([g0.reshape(*filt_shape)] * n_channels, axis=0)
    g1_rep = cp.concatenate([g1.reshape(*filt_shape)] * n_channels, axis=0)

    low_part = _conv_transpose2d(cp.asarray(lo), cp.asarray(g0_rep), stride=stride, pad=pad, groups=n_channels)
    high_part = _conv_transpose2d(cp.asarray(hi), cp.asarray(g1_rep), stride=stride, pad=pad, groups=n_channels)
    return low_part + high_part
210
+
211
+
212
class DWTForward():
    """ Single-level 2D forward discrete wavelet transform of a batch of images.

    Args:
        wave (str): name of the PyWavelets wavelet to use.
    """

    def __init__(self, wave='db1'):
        super().__init__()
        wavelet = pywt.Wavelet(wave)

        def _as_filter(coeffs, shape):
            # afb1d relies on _conv2d, which is a correlation: time-reversing the
            # analysis filters makes it behave as a convolution.
            return cp.array(coeffs).astype('float32')[::-1].reshape(shape)

        col_shape, row_shape = (1, 1, -1, 1), (1, 1, 1, -1)
        self.h0_col = _as_filter(wavelet.dec_lo, col_shape)
        self.h1_col = _as_filter(wavelet.dec_hi, col_shape)
        self.h0_row = _as_filter(wavelet.dec_lo, row_shape)
        self.h1_row = _as_filter(wavelet.dec_hi, row_shape)

    def apply(self, x):
        """ Forward pass of the DWT (one decomposition level).

        Args:
            x (array): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yl has shape :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`(N, C_{in}, 3, H_{in}', W_{in}')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}'` denote the downsampled shapes.
        """
        # filter along the last axis, then along the second-to-last one
        row_filtered = afb1d(x, self.h0_row, self.h1_row, dim=3)
        bands = afb1d(row_filtered, self.h0_col, self.h1_col, dim=2)
        n_batch, height, width = bands.shape[0], bands.shape[-2], bands.shape[-1]
        bands = bands.reshape(n_batch, -1, 4, height, width)  # pylint: disable=E1121 # this might blow up in the future
        lowpass = cp.ascontiguousarray(bands[:, :, 0])
        highpass = cp.ascontiguousarray(bands[:, :, 1:])
        return lowpass, highpass
262
+
263
+
264
class DWTInverse():
    """ Single-level 2D inverse discrete wavelet transform of a batch of images.

    Args:
        wave (str): name of the PyWavelets wavelet to use.
    """

    def __init__(self, wave='db1'):
        super().__init__()
        wavelet = pywt.Wavelet(wave)
        col_shape, row_shape = (1, 1, -1, 1), (1, 1, 1, -1)
        # synthesis filters, used as-is (no time reversal)
        self.g0_col = cp.array(wavelet.rec_lo).astype('float32').reshape(col_shape)
        self.g1_col = cp.array(wavelet.rec_hi).astype('float32').reshape(col_shape)
        self.g0_row = cp.array(wavelet.rec_lo).astype('float32').reshape(row_shape)
        self.g1_row = cp.array(wavelet.rec_hi).astype('float32').reshape(row_shape)

    def apply(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where
              yl is a lowpass array of shape :math:`(N, C_{in}, H_{in}', W_{in}')`
              and yh is a bandpass array of shape :math:`(N, C_{in}, 3, H_{in}', W_{in}')`,
              i.e. the format returned by DWTForward.apply

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
        """
        lowpass, highpass = coeffs
        # undo the column filtering first, then the row filtering
        low_rows = sfb1d(lowpass, highpass[:, :, 0], self.g0_col, self.g1_col, dim=2)
        high_rows = sfb1d(highpass[:, :, 1], highpass[:, :, 2], self.g0_col, self.g1_col, dim=2)
        return sfb1d(low_rows, high_rows, self.g0_row, self.g1_row, dim=3)
307
+
308
+
309
def remove_stripe_fw(data, sigma, wname, level):
    """
    Remove stripes with combined wavelet / Fourier filtering.

    Parameters
    ----------
    data: array
        3D array of shape (n_proj, n_z, n_x); data[:, z, :] is the sinogram of slice z.
    sigma: float
        Width of the damping applied in Fourier space: each retained band is multiplied by
        1 - exp(-f**2 / (2 * sigma**2)) along the projection axis.
    wname: str
        Wavelet name (any wavelet available in PyWavelets).
    level: int
        Number of wavelet decomposition levels.

    Returns
    -------
    New array with the same layout and dtype as `data`.
    """
    n_proj, n_z, n_x = data.shape
    # n_proj // 8 extra zero rows, split around the sinograms along the projection axis
    # (presumably to limit boundary effects of the filtering)
    n_proj_pad = n_proj + n_proj // 8
    start = (n_proj_pad - n_proj) // 2
    stop = (n_proj_pad + n_proj) // 2

    forward = DWTForward(wave=wname)
    inverse = DWTInverse(wave=wname)

    sinos = cp.zeros([n_z, 1, n_proj_pad, n_x], dtype='float32')
    sinos[:, 0, start:stop] = data.astype('float32').swapaxes(0, 1)

    # Decomposition: damp the low frequencies (along the projection axis) of detail band 1
    details = []
    for _ in range(level):
        sinos, band = forward.apply(sinos)
        details.append(band)
        spectrum = cp.fft.fft(band[:, 0, 1], axis=1)
        _, n_freq, _ = spectrum.shape
        freqs = cp.fft.ifftshift((cp.arange(-n_freq, n_freq, 2) + 1) / 2)
        damping = -cp.expm1(-freqs**2 / (2 * sigma**2))
        spectrum *= damping[:, None]
        band[:, 0, 1] = cp.fft.ifft(spectrum, n_freq, axis=1).real

    # Reconstruction, from the coarsest level to the finest
    for band in reversed(details):
        band_h, band_w = band[0, 0, 1].shape
        sinos = sinos[:, :, :band_h, :band_w]
        sinos = inverse.apply((sinos, band))

    result = sinos[:, 0, start:stop, :n_x].astype(data.dtype)
    return result.swapaxes(0, 1)
351
+
352
+ ######## Titarenko ring removal ############################################################################################################################################################################
353
def remove_stripe_ti(data, beta, mask_size):
    """
    Remove stripes with the method by V. Titarenko.

    Parameters
    ----------
    data: array
        3D array. A correction profile is estimated from the mean of `data` along axis 0,
        and added to every data[i]. `data` is modified in place and returned.
    beta: float
        Filter parameter (presumably in (0, 1) -- TODO confirm against the reference).
    mask_size: float
        Fraction of the last-axis width, centered, where the correction is applied.
        Fractional values are accepted (they used to produce float slice bounds).

    Returns
    -------
    data: the corrected array (same object as the input).
    """
    n_x = data.shape[-1]
    gamma = beta * ((1 - beta) / (1 + beta)) ** cp.abs(cp.fft.fftfreq(n_x) * n_x)
    gamma[0] -= 1
    profile = cp.mean(data, axis=0)
    profile = profile - profile[:, 0:1]
    # Pass n explicitly: without it, irfft returns 2 * (m - 1) samples, i.e. n_x - 1 for odd n_x
    profile = cp.fft.irfft(cp.fft.rfft(profile) * cp.fft.rfft(gamma), n=n_x)

    mask = cp.zeros(profile.shape, dtype=profile.dtype)
    width = mask.shape[1]
    half = int(mask_size * width) // 2
    start = max(width // 2 - half, 0)
    mask[:, start:width // 2 + half] = 1
    data[:] += profile * mask
    return data
366
+
367
+
368
+ ######## Optimized version for Vo-all ring removal in tomopy################################################################################################################################################################
369
+ def _rs_sort(sinogram, size, matindex, dim):
370
+ """
371
+ Remove stripes using the sorting technique.
372
+ """
373
+ sinogram = cp.transpose(sinogram)
374
+ matcomb = cp.asarray(cp.dstack((matindex, sinogram)))
375
+
376
+ # matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcomb])
377
+ ids = cp.argsort(matcomb[:,:,1],axis=1)
378
+ matsort = matcomb.copy()
379
+ matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1)
380
+ matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1)
381
+ if dim == 1:
382
+ matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, 1))
383
+ else:
384
+ matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, size))
385
+
386
+ # matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
387
+
388
+ ids = cp.argsort(matsort[:,:,0],axis=1)
389
+ matsortback = matsort.copy()
390
+ matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1)
391
+ matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1)
392
+
393
+ sino_corrected = matsortback[:, :, 1]
394
+ return cp.transpose(sino_corrected)
395
+
396
+ def _mpolyfit(x,y):
397
+ n= len(x)
398
+ x_mean = cp.mean(x)
399
+ y_mean = cp.mean(y)
400
+
401
+ Sxy = cp.sum(x*y) - n*x_mean*y_mean
402
+ Sxx = cp.sum(x*x) - n*x_mean*x_mean
403
+
404
+ slope = Sxy / Sxx
405
+ intercept = y_mean - slope*x_mean
406
+ return slope,intercept
407
+
408
def _detect_stripe(listdata, snr):
    """
    Algorithm 4 in :cite:`Vo:18`. Used to locate stripes.

    Parameters
    ----------
    listdata: 1D array
        Profile in which outliers (stripes) are searched.
    snr: float
        Ratio used to decide whether the extreme values are outliers. Greater is less sensitive.

    Returns
    -------
    listmask: array like `listdata`, 1.0 where a stripe is detected and 0.0 elsewhere.
    """
    numdata = len(listdata)
    listsorted = cp.sort(listdata)[::-1]
    xlist = cp.arange(0, numdata, 1.0)
    # Python int rather than int16: int16 overflows for profiles longer than ~131k samples
    ndrop = int(0.25 * numdata)
    # straight-line fit of the central part of the sorted profile (both ends dropped)
    (_slope, _intercept) = _mpolyfit(xlist[ndrop:-ndrop - 1], listsorted[ndrop:-ndrop - 1])

    numt1 = _intercept + _slope * xlist[-1]
    noiselevel = cp.abs(numt1 - _intercept)
    noiselevel = cp.clip(noiselevel, 1e-6, None)
    val1 = cp.abs(listsorted[0] - _intercept) / noiselevel
    val2 = cp.abs(listsorted[-1] - numt1) / noiselevel
    listmask = cp.zeros_like(listdata)
    if val1 >= snr:
        upper_thresh = _intercept + noiselevel * snr * 0.5
        listmask[listdata > upper_thresh] = 1.0
    if val2 >= snr:
        lower_thresh = numt1 - noiselevel * snr * 0.5
        listmask[listdata <= lower_thresh] = 1.0
    return listmask
433
+
434
def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
    """
    Remove large stripes.

    Parameters
    ----------
    sinogram: 2D array of shape (nrow, ncol)
    snr: float
        Ratio used by _detect_stripe to locate stripes. Greater is less sensitive.
    size: int
        Width of the median filter applied along axis 1 of the column-sorted sinogram.
    matindex: 2D array with the shape of sinogram.T
        Index array, as built by _create_matindex(ncol, nrow).
    drop_ratio: float
        Fraction of rows (clipped to [0, 0.8]) discarded, half at each end of the sorted
        columns, when computing the column means.
    norm: bool
        If True, the sinogram is first divided by the per-column correction factors.
        When norm is False, `sinogram` itself is modified in place.

    Returns
    -------
    The corrected sinogram.
    """
    drop_ratio = max(min(drop_ratio, 0.8), 0)
    nrow = sinogram.shape[0]
    ndrop = int(0.5 * drop_ratio * nrow)
    sinosort = cp.sort(sinogram, axis=0)
    sinosmooth = median_filter(sinosort, (1, size))
    list1 = cp.mean(sinosort[ndrop:nrow - ndrop], axis=0)
    list2 = cp.mean(sinosmooth[ndrop:nrow - ndrop], axis=0)
    # Columns whose smoothed mean is zero get a neutral factor of 1 instead of inf/nan,
    # which would otherwise end up in the normalized sinogram.
    valid = list2 != 0
    listfact = cp.where(valid, list1 / cp.where(valid, list2, 1), cp.ones_like(list1))

    # Locate stripes
    listmask = _detect_stripe(listfact, snr)
    listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
    matfact = cp.tile(listfact, (nrow, 1))
    # Normalize
    if norm is True:
        sinogram = sinogram / matfact
    sinogram1 = cp.transpose(sinogram)
    matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))

    # Sort each row by value, carrying the original indices along
    ids = cp.argsort(matcombine[:, :, 1], axis=1)
    matsort = matcombine.copy()
    matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
    matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)

    # Replace the sorted values by the smoothed column-sorted sinogram, then undo the sort
    matsort[:, :, 1] = cp.transpose(sinosmooth)
    ids = cp.argsort(matsort[:, :, 0], axis=1)
    matsortback = matsort.copy()
    matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
    matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)

    # Only the columns flagged as stripes are replaced
    sino_corrected = cp.transpose(matsortback[:, :, 1])
    listxmiss = cp.where(listmask > 0.0)[0]
    sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
    return sinogram
479
+
480
def _rs_dead(sinogram, snr, size, matindex, norm=True):
    """
    Remove unresponsive and fluctuating stripes.

    Parameters
    ----------
    sinogram: 2D array of shape (nrow, ncol)
        Not modified: the function works on a copy.
    snr: float
        Ratio used by _detect_stripe to locate stripes. Greater is less sensitive.
    size: int
        Size of the median filter estimating the background of the per-column
        fluctuation profile; also passed to _rs_large.
    matindex: 2D array with the shape of sinogram.T
        Index array, as built by _create_matindex(ncol, nrow).
    norm: bool
        If True, residual stripes are then removed with _rs_large.

    Returns
    -------
    The corrected sinogram.
    """
    sinogram = cp.copy(sinogram)  # Make it mutable
    nrow = sinogram.shape[0]
    sinosmooth = uniform_filter1d(sinogram, 10, axis=0)

    # per-column amount of fluctuation, compared with its median-filtered background
    listdiff = cp.sum(cp.abs(sinogram - sinosmooth), axis=0)
    listdiffbck = median_filter(listdiff, size)
    # a zero background gives a neutral factor of 1 instead of inf/nan
    valid = listdiffbck != 0
    listfact = cp.where(valid, listdiff / cp.where(valid, listdiffbck, 1), cp.ones_like(listdiff))

    listmask = _detect_stripe(listfact, snr)
    listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
    # The two first/last columns are never flagged, so every flagged column has valid
    # neighbours on both sides for the linear interpolation below.
    listmask[0:2] = 0.0
    listmask[-2:] = 0.0
    listx = cp.where(listmask < 1.0)[0]
    matz = sinogram[:, listx]

    listxmiss = cp.where(listmask > 0.0)[0]

    if len(listxmiss) > 0:
        # linear interpolation (along the columns) from the nearest unflagged columns
        ids = cp.searchsorted(listx, listxmiss)
        sinogram[:, listxmiss] = matz[:, ids - 1] + (listxmiss - listx[ids - 1]) * (
            matz[:, ids] - matz[:, ids - 1]
        ) / (listx[ids] - listx[ids - 1])

    # Remove residual stripes
    if norm is True:
        sinogram = _rs_large(sinogram, snr, size, matindex)
    return sinogram
515
+
516
+
517
+ def _create_matindex(nrow, ncol):
518
+ """
519
+ Create a 2D array of indexes used for the sorting technique.
520
+ """
521
+ listindex = cp.arange(0.0, ncol, 1.0)
522
+ matindex = cp.tile(listindex, (nrow, 1))
523
+ return matindex
524
+
525
def remove_all_stripe(tomo, snr=3, la_size=61, sm_size=21, dim=1):
    """
    Remove all types of stripe artifacts from sinogram using Nghia Vo's
    approach :cite:`Vo:18` (combination of algorithm 3,4,5, and 6).

    Parameters
    ----------
    tomo : ndarray
        3D tomographic data; tomo[:, i, :] is sinogram number i.
        It is corrected in place (each sinogram is overwritten).
    snr : float
        Ratio used to locate large stripes.
        Greater is less sensitive.
    la_size : int
        Window size of the median filter to remove large stripes.
    sm_size : int
        Window size of the median filter to remove small-to-medium stripes.
    dim : {1, 2}, optional
        Dimension of the window.

    Returns
    -------
    ndarray
        Corrected 3D tomographic data (same object as `tomo`).
    """
    n_angles, n_sinograms, n_x = tomo.shape[0], tomo.shape[1], tomo.shape[2]
    # one row of angle indices per detector column, shared by all sinograms
    matindex = _create_matindex(n_x, n_angles)
    for i_sino in range(n_sinograms):
        sinogram = _rs_dead(tomo[:, i_sino, :], snr, la_size, matindex)
        tomo[:, i_sino, :] = _rs_sort(sinogram, sm_size, matindex, dim)
    return tomo
560
+
561
+
562
+ from ..cuda.utils import pycuda_to_cupy
563
def remove_all_stripe_pycuda(radios, device_id=0, **kwargs):
    """
    Nabu interface to "remove_all_stripe". In-place!

    Parameters
    ----------
    radios: pycuda.GPUArray
        Stack of radios in the shape (n_angles, n_y, n_x)
        so that sinogram number i is radios[:, i, :]
    device_id: int, optional
        CUDA device used by cupy. It is selected on the first call, and selected again
        whenever a call requests a different device than the previous one.

    Other Parameters
    ----------------
    See parameters of 'remove_all_stripe'

    Returns
    -------
    radios: the same pycuda array, processed in place.
    """
    # Previously the device was selected only once, so a device_id passed in a later
    # call was silently ignored.
    if getattr(remove_all_stripe, "_cupy_device_id", None) != device_id:
        from cupy import cuda

        cuda.Device(device_id).use()
        setattr(remove_all_stripe, "_cupy_device_id", device_id)
    setattr(remove_all_stripe, "_cupy_init", True)  # kept for backward compatibility

    cupy_radios = pycuda_to_cupy(radios)  # no memory copy: the cupy array wraps the pycuda buffer
    remove_all_stripe(cupy_radios, **kwargs)
    return radios
586
+
nabu/utils.py CHANGED
@@ -1,13 +1,11 @@
1
1
  from functools import partial
2
2
  import os
3
3
  import sys
4
- from typing import Union
5
- import typing
4
+ from functools import partial, lru_cache
5
+ from itertools import product
6
6
  import warnings
7
7
  from time import time
8
8
  import posixpath
9
- from itertools import product
10
- from functools import lru_cache
11
9
  import numpy as np
12
10
 
13
11
 
@@ -424,6 +422,19 @@ def copy_dict_items(dict_, keys):
424
422
  return res
425
423
 
426
424
 
425
def remove_first_dict_items(dict_, n_items, sort_func=None, inplace=True):
    """
    Remove the first items of a dictionary. The keys have to be sortable.

    Keys are ordered with ``sorted(dict_.keys(), key=sort_func)``, and the first
    `n_items` of them are removed.

    Parameters
    ----------
    dict_: dict
        Dictionary to process.
    n_items: int
        Number of leading keys to remove.
    sort_func: callable, optional
        Key function used for sorting the dictionary keys.
    inplace: bool, optional
        If True (default), `dict_` is modified and returned. Otherwise `dict_` is left
        untouched and copy_dict_items() builds the result from the remaining keys.
    """
    ordered_keys = sorted(dict_.keys(), key=sort_func)
    if not inplace:
        return copy_dict_items(dict_, ordered_keys[n_items:])
    for doomed in ordered_keys[:n_items]:
        dict_.pop(doomed)
    return dict_
436
+
437
+
427
438
  def recursive_copy_dict(dict_):
428
439
  """
429
440
  Perform a shallow copy of a dictionary of dictionaries.
@@ -537,7 +548,7 @@ def filter_str_def(elmt):
537
548
  return elmt
538
549
 
539
550
 
540
- def convert_str_to_tuple(input_str: str, none_if_empty: bool = False) -> Union[None, tuple]:
551
+ def convert_str_to_tuple(input_str: str, none_if_empty: bool = False):
541
552
  """
542
553
  :param str input_str: string to convert
543
554
  :param bool none_if_empty: if true and the conversion is an empty tuple
@@ -572,7 +583,7 @@ class Progress:
572
583
  self._name = name
573
584
  self.reset()
574
585
 
575
- def reset(self, max_: typing.Union[None, int] = None) -> None:
586
+ def reset(self, max_=None):
576
587
  """
577
588
  reset the advancement to n and max advancement to max_
578
589
  :param int max_:
@@ -630,6 +641,19 @@ def concatenate_dict(dict_1, dict_2) -> dict:
630
641
  return res
631
642
 
632
643
 
644
class BaseClassError:
    """
    Placeholder class that cannot be instantiated.

    Any instantiation attempt, whatever the arguments, raises ValueError("Base class").
    """

    def __init__(self, *args, **kwargs):
        error_message = "Base class"
        raise ValueError(error_message)
647
+
648
+
649
def MissingComponentError(msg):
    """
    Build a stand-in class for a component that is not available.

    Parameters
    ----------
    msg: str
        Message of the RuntimeError raised when the returned class is instantiated.

    Returns
    -------
    A class named MissingComponentCls whose instantiation (with any arguments)
    raises RuntimeError(msg).
    """
    error_message = msg

    class MissingComponentCls:
        def __init__(self, *args, **kwargs):
            raise RuntimeError(error_message)

    return MissingComponentCls
655
+
656
+
633
657
  # ------------------------------------------------------------------------------
634
658
  # ------------------------ Image (move elsewhere ?) ----------------------------
635
659
  # ------------------------------------------------------------------------------
@@ -649,6 +673,8 @@ def generate_coords(img_shp, center=None):
649
673
 
650
674
  def clip_circle(img, center=None, radius=None):
651
675
  R, C = generate_coords(img.shape, center)
676
+ if radius is None:
677
+ radius = R.shape[-1] // 2
652
678
  M = R**2 + C**2
653
679
  res = np.zeros_like(img)
654
680
  res[M < radius**2] = img[M < radius**2]
@@ -1,6 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2020-2023 ESRF
3
+ Copyright (c) 2020-2024 ESRF
4
4
 
5
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
  of this software and associated documentation files (the "Software"), to deal