nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/conf.py +1 -1
- doc/doc_config.py +32 -0
- nabu/__init__.py +2 -1
- nabu/app/bootstrap_stitching.py +1 -1
- nabu/app/cli_configs.py +122 -2
- nabu/app/composite_cor.py +27 -2
- nabu/app/correct_rot.py +70 -0
- nabu/app/create_distortion_map_from_poly.py +42 -18
- nabu/app/diag_to_pix.py +358 -0
- nabu/app/diag_to_rot.py +449 -0
- nabu/app/generate_header.py +4 -3
- nabu/app/histogram.py +2 -2
- nabu/app/multicor.py +6 -1
- nabu/app/parse_reconstruction_log.py +151 -0
- nabu/app/prepare_weights_double.py +83 -22
- nabu/app/reconstruct.py +5 -1
- nabu/app/reconstruct_helical.py +7 -0
- nabu/app/reduce_dark_flat.py +6 -3
- nabu/app/rotate.py +4 -4
- nabu/app/stitching.py +16 -2
- nabu/app/tests/test_reduce_dark_flat.py +18 -2
- nabu/app/validator.py +4 -4
- nabu/cuda/convolution.py +8 -376
- nabu/cuda/fft.py +4 -0
- nabu/cuda/kernel.py +4 -4
- nabu/cuda/medfilt.py +5 -158
- nabu/cuda/padding.py +5 -71
- nabu/cuda/processing.py +23 -2
- nabu/cuda/src/ElementOp.cu +78 -0
- nabu/cuda/src/backproj.cu +28 -2
- nabu/cuda/src/fourier_wavelets.cu +2 -2
- nabu/cuda/src/normalization.cu +23 -0
- nabu/cuda/src/padding.cu +2 -2
- nabu/cuda/src/transpose.cu +16 -0
- nabu/cuda/utils.py +39 -0
- nabu/estimation/alignment.py +10 -1
- nabu/estimation/cor.py +808 -38
- nabu/estimation/cor_sino.py +7 -9
- nabu/estimation/tests/test_cor.py +85 -3
- nabu/io/reader.py +26 -18
- nabu/io/tests/test_cast_volume.py +3 -3
- nabu/io/tests/test_detector_distortion.py +3 -3
- nabu/io/tiffwriter_zmm.py +2 -2
- nabu/io/utils.py +14 -4
- nabu/io/writer.py +5 -3
- nabu/misc/fftshift.py +6 -0
- nabu/misc/histogram.py +5 -285
- nabu/misc/histogram_cuda.py +8 -104
- nabu/misc/kernel_base.py +3 -121
- nabu/misc/padding_base.py +5 -69
- nabu/misc/processing_base.py +3 -107
- nabu/misc/rotation.py +5 -62
- nabu/misc/rotation_cuda.py +5 -65
- nabu/misc/transpose.py +6 -0
- nabu/misc/unsharp.py +3 -78
- nabu/misc/unsharp_cuda.py +5 -52
- nabu/misc/unsharp_opencl.py +8 -85
- nabu/opencl/fft.py +6 -0
- nabu/opencl/kernel.py +21 -6
- nabu/opencl/padding.py +5 -72
- nabu/opencl/processing.py +27 -5
- nabu/opencl/src/backproj.cl +3 -3
- nabu/opencl/src/fftshift.cl +65 -12
- nabu/opencl/src/padding.cl +2 -2
- nabu/opencl/src/roll.cl +96 -0
- nabu/opencl/src/transpose.cl +16 -0
- nabu/pipeline/config_validators.py +63 -3
- nabu/pipeline/dataset_validator.py +2 -2
- nabu/pipeline/estimators.py +193 -35
- nabu/pipeline/fullfield/chunked.py +34 -17
- nabu/pipeline/fullfield/chunked_cuda.py +7 -5
- nabu/pipeline/fullfield/computations.py +48 -13
- nabu/pipeline/fullfield/nabu_config.py +13 -13
- nabu/pipeline/fullfield/processconfig.py +10 -5
- nabu/pipeline/fullfield/reconstruction.py +1 -2
- nabu/pipeline/helical/fbp.py +5 -0
- nabu/pipeline/helical/filtering.py +12 -9
- nabu/pipeline/helical/gridded_accumulator.py +179 -33
- nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
- nabu/pipeline/helical/helical_reconstruction.py +56 -18
- nabu/pipeline/helical/span_strategy.py +1 -1
- nabu/pipeline/helical/tests/test_accumulator.py +4 -0
- nabu/pipeline/params.py +23 -2
- nabu/pipeline/processconfig.py +3 -8
- nabu/pipeline/tests/test_chunk_reader.py +78 -0
- nabu/pipeline/tests/test_estimators.py +120 -2
- nabu/pipeline/utils.py +25 -0
- nabu/pipeline/writer.py +2 -0
- nabu/preproc/ccd_cuda.py +9 -7
- nabu/preproc/ctf.py +21 -26
- nabu/preproc/ctf_cuda.py +25 -25
- nabu/preproc/double_flatfield.py +14 -2
- nabu/preproc/double_flatfield_cuda.py +7 -11
- nabu/preproc/flatfield_cuda.py +23 -27
- nabu/preproc/phase.py +19 -24
- nabu/preproc/phase_cuda.py +21 -21
- nabu/preproc/shift_cuda.py +58 -28
- nabu/preproc/tests/test_ctf.py +5 -5
- nabu/preproc/tests/test_double_flatfield.py +2 -2
- nabu/preproc/tests/test_vshift.py +13 -2
- nabu/processing/__init__.py +0 -0
- nabu/processing/convolution_cuda.py +375 -0
- nabu/processing/fft_base.py +163 -0
- nabu/processing/fft_cuda.py +256 -0
- nabu/processing/fft_opencl.py +54 -0
- nabu/processing/fftshift.py +134 -0
- nabu/processing/histogram.py +286 -0
- nabu/processing/histogram_cuda.py +103 -0
- nabu/processing/kernel_base.py +126 -0
- nabu/processing/medfilt_cuda.py +159 -0
- nabu/processing/muladd.py +29 -0
- nabu/processing/muladd_cuda.py +68 -0
- nabu/processing/padding_base.py +71 -0
- nabu/processing/padding_cuda.py +75 -0
- nabu/processing/padding_opencl.py +77 -0
- nabu/processing/processing_base.py +123 -0
- nabu/processing/roll_opencl.py +64 -0
- nabu/processing/rotation.py +63 -0
- nabu/processing/rotation_cuda.py +66 -0
- nabu/processing/tests/__init__.py +0 -0
- nabu/processing/tests/test_fft.py +268 -0
- nabu/processing/tests/test_fftshift.py +71 -0
- nabu/{misc → processing}/tests/test_histogram.py +2 -4
- nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
- nabu/processing/tests/test_muladd.py +54 -0
- nabu/{cuda → processing}/tests/test_padding.py +119 -75
- nabu/processing/tests/test_roll.py +63 -0
- nabu/{misc → processing}/tests/test_rotation.py +3 -2
- nabu/processing/tests/test_transpose.py +72 -0
- nabu/{misc → processing}/tests/test_unsharp.py +41 -8
- nabu/processing/transpose.py +126 -0
- nabu/processing/unsharp.py +79 -0
- nabu/processing/unsharp_cuda.py +53 -0
- nabu/processing/unsharp_opencl.py +75 -0
- nabu/reconstruction/fbp.py +34 -10
- nabu/reconstruction/fbp_base.py +35 -16
- nabu/reconstruction/fbp_opencl.py +7 -12
- nabu/reconstruction/filtering.py +2 -2
- nabu/reconstruction/filtering_cuda.py +13 -14
- nabu/reconstruction/filtering_opencl.py +3 -4
- nabu/reconstruction/projection.py +2 -0
- nabu/reconstruction/rings.py +158 -1
- nabu/reconstruction/rings_cuda.py +218 -58
- nabu/reconstruction/sinogram_cuda.py +16 -12
- nabu/reconstruction/tests/test_deringer.py +116 -14
- nabu/reconstruction/tests/test_fbp.py +22 -31
- nabu/reconstruction/tests/test_filtering.py +11 -2
- nabu/resources/dataset_analyzer.py +89 -26
- nabu/resources/nxflatfield.py +2 -2
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +9 -2
- nabu/stitching/alignment.py +184 -0
- nabu/stitching/config.py +241 -39
- nabu/stitching/definitions.py +6 -0
- nabu/stitching/frame_composition.py +4 -2
- nabu/stitching/overlap.py +99 -3
- nabu/stitching/sample_normalization.py +60 -0
- nabu/stitching/slurm_utils.py +10 -10
- nabu/stitching/tests/test_alignment.py +99 -0
- nabu/stitching/tests/test_config.py +16 -1
- nabu/stitching/tests/test_overlap.py +68 -2
- nabu/stitching/tests/test_sample_normalization.py +49 -0
- nabu/stitching/tests/test_slurm_utils.py +5 -5
- nabu/stitching/tests/test_utils.py +3 -33
- nabu/stitching/tests/test_z_stitching.py +391 -22
- nabu/stitching/utils.py +144 -202
- nabu/stitching/z_stitching.py +309 -126
- nabu/testutils.py +18 -0
- nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
- nabu/utils.py +32 -6
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
- nabu-2024.1.0rc3.dist-info/RECORD +296 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
- nabu/conftest.py +0 -14
- nabu/opencl/fftshift.py +0 -92
- nabu/opencl/tests/test_fftshift.py +0 -55
- nabu/opencl/tests/test_padding.py +0 -84
- nabu-2023.2.1.dist-info/RECORD +0 -252
- /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/testutils.py
CHANGED
@@ -59,6 +59,24 @@ def get_data(*dataset_path):
|
|
59
59
|
return np.load(dataset_downloaded_path)
|
60
60
|
|
61
61
|
|
62
|
+
def get_array_of_given_shape(img, shape, dtype):
    """
    From a given image, returns an array of the wanted shape and dtype.

    The 2D image `img` is tiled (repeated) along its two axes until it is large
    enough, then cropped.

    Parameters
    ----------
    img: numpy.ndarray
        2D image used as a pattern.
    shape: tuple of int
        Wanted shape: (n,), (ny, nx) or (n_images, ny, nx).
        For a 3D shape, the same cropped (ny, nx) image is stacked n_images times.
    dtype: numpy.dtype
        Wanted data type.

    Returns
    -------
    arr: numpy.ndarray
        C-contiguous array of the wanted dtype. Note that singleton dimensions
        are removed (np.squeeze), e.g. shape (1, ny, nx) gives an array of shape (ny, nx).
    """
    # The image is tiled along its two spatial axes, so it must be compared against the
    # spatial part of "shape". For a stack (n_images, ny, nx), the first axis is the number
    # of images, not a spatial dimension.
    target = tuple(shape[-2:]) if len(shape) > 2 else tuple(shape)
    # Tile image until it's big enough.
    # NB: keep a list inside any(): using any(<generator>) was observed to crash here.
    while any([i_dim < s_dim for i_dim, s_dim in zip(img.shape, target)]):
        img = np.tile(img, (2, 2))
    if len(shape) == 1:
        arr = img[: shape[0], 0]
    elif len(shape) == 2:
        arr = img[: shape[0], : shape[1]]
    else:
        arr = np.tile(img, (shape[0], 1, 1))[: shape[0], : shape[1], : shape[2]]
    return np.ascontiguousarray(np.squeeze(arr), dtype=dtype)
|
78
|
+
|
79
|
+
|
62
80
|
def get_big_data(filename):
|
63
81
|
if __big_testdata_dir__ is None:
|
64
82
|
return None
|
@@ -0,0 +1,586 @@
|
|
1
|
+
# pylint: skip-file
|
2
|
+
|
3
|
+
"""
|
4
|
+
This file is a "GPU" (through cupy) implementation of "remove_all_stripe".
|
5
|
+
The original method is implemented by Nghia Vo in the algotom project: https://github.com/algotom/algotom/blob/master/algotom/prep/removal.py
|
6
|
+
The implementation using cupy is done by Viktor Nikitin in the tomocupy project: https://github.com/tomography/tomocupy/blame/main/src/tomocupy/remove_stripe.py
|
7
|
+
License follows.
|
8
|
+
|
9
|
+
For now we can't rely on off-the-shelf tomocupy as it's not packaged in pypi, and compilation is quite tedious.
|
10
|
+
"""
|
11
|
+
|
12
|
+
# *************************************************************************** #
|
13
|
+
# Copyright © 2022, UChicago Argonne, LLC #
|
14
|
+
# All Rights Reserved #
|
15
|
+
# Software Name: Tomocupy #
|
16
|
+
# By: Argonne National Laboratory #
|
17
|
+
# #
|
18
|
+
# OPEN SOURCE LICENSE #
|
19
|
+
# #
|
20
|
+
# Redistribution and use in source and binary forms, with or without #
|
21
|
+
# modification, are permitted provided that the following conditions are met: #
|
22
|
+
# #
|
23
|
+
# 1. Redistributions of source code must retain the above copyright notice, #
|
24
|
+
# this list of conditions and the following disclaimer. #
|
25
|
+
# 2. Redistributions in binary form must reproduce the above copyright #
|
26
|
+
# notice, this list of conditions and the following disclaimer in the #
|
27
|
+
# documentation and/or other materials provided with the distribution. #
|
28
|
+
# 3. Neither the name of the copyright holder nor the names of its #
|
29
|
+
# contributors may be used to endorse or promote products derived #
|
30
|
+
# from this software without specific prior written permission. #
|
31
|
+
# #
|
32
|
+
# #
|
33
|
+
# *************************************************************************** #
|
34
|
+
# DISCLAIMER #
|
35
|
+
# #
|
36
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS #
|
37
|
+
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT #
|
38
|
+
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS #
|
39
|
+
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT #
|
40
|
+
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, #
|
41
|
+
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED #
|
42
|
+
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR #
|
43
|
+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF #
|
44
|
+
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING #
|
45
|
+
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS #
|
46
|
+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
|
47
|
+
# *************************************************************************** #
|
48
|
+
|
49
|
+
try:
|
50
|
+
import cupy as cp
|
51
|
+
import pywt
|
52
|
+
from cupyx.scipy.ndimage import median_filter
|
53
|
+
from cupyx.scipy import signal
|
54
|
+
from cupyx.scipy.ndimage import binary_dilation
|
55
|
+
from cupyx.scipy.ndimage import uniform_filter1d
|
56
|
+
__have_tomocupy_deringer__ = True
|
57
|
+
except ImportError as err:
|
58
|
+
__have_tomocupy_deringer__ = False
|
59
|
+
__tomocupy_deringer_import_error__ = err
|
60
|
+
|
61
|
+
|
62
|
+
###### Ring removal with wavelet filtering (adapted for cupy from the pytorch_wavelets package https://pytorch-wavelets.readthedocs.io/) ######
|
63
|
+
|
64
|
+
def _reflect(x, minx, maxx):
|
65
|
+
"""Reflect the values in matrix *x* about the scalar values *minx* and
|
66
|
+
*maxx*. Hence a vector *x* containing a long linearly increasing series is
|
67
|
+
converted into a waveform which ramps linearly up and down between *minx*
|
68
|
+
and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
|
69
|
+
0.5), the ramps will have repeated max and min samples.
|
70
|
+
|
71
|
+
.. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
|
72
|
+
.. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
|
73
|
+
|
74
|
+
"""
|
75
|
+
x = cp.asanyarray(x)
|
76
|
+
rng = maxx - minx
|
77
|
+
rng_by_2 = 2 * rng
|
78
|
+
mod = cp.fmod(x - minx, rng_by_2)
|
79
|
+
normed_mod = cp.where(mod < 0, mod + rng_by_2, mod)
|
80
|
+
out = cp.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
|
81
|
+
return cp.array(out, dtype=x.dtype)
|
82
|
+
|
83
|
+
|
84
|
+
def _mypad(x, pad, value=0):
    """Symmetric (half-sample) padding of the last two axes of a 4-D array.

    The padding mirrors the array including its edge sample: padding ``[a, b, c]``
    by (1, 2) gives ``[a, a, b, c, c, b]``. Both axes can be padded in the same
    call; previously, padding both axes at once silently returned None.

    Inputs:
        x (array): 4-D array to pad (only the last two axes are padded).
        pad (tuple): tuple of (left, right, top, bottom) pad sizes. Left/right
            apply to the last axis, top/bottom to the second-to-last axis.
        value: unused, kept for interface compatibility (no constant padding is done).

    Returns:
        The padded array (a new array obtained by fancy indexing).
    """
    left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
    pad_rows = top != 0 or bottom != 0
    pad_cols = left != 0 or right != 0
    # The all-zero case goes through the row branch (as it always did), which returns a copy.
    if pad_rows or not pad_cols:
        n_rows = x.shape[-2]
        rows = _reflect(cp.arange(-top, n_rows + bottom, dtype='int32'), -0.5, n_rows - 0.5)
        x = x[:, :, rows]
    if pad_cols:
        n_cols = x.shape[-1]
        cols = _reflect(cp.arange(-left, n_cols + right, dtype='int32'), -0.5, n_cols - 0.5)
        x = x[:, :, :, cols]
    return x
|
104
|
+
|
105
|
+
|
106
|
+
def _conv2d(x, w, stride, pad, groups=1):
|
107
|
+
""" Convolution (equivalent pytorch.conv2d)
|
108
|
+
"""
|
109
|
+
if pad != 0:
|
110
|
+
x = cp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), 'constant')
|
111
|
+
|
112
|
+
b, ci, hi, wi = x.shape
|
113
|
+
co, _, hk, wk = w.shape
|
114
|
+
ho = int(cp.floor(1 + (hi - hk) / stride[0]))
|
115
|
+
wo = int(cp.floor(1 + (wi - wk) / stride[1]))
|
116
|
+
out = cp.zeros([b, co, ho, wo], dtype='float32')
|
117
|
+
x = cp.expand_dims(x, axis=1)
|
118
|
+
w = cp.expand_dims(w, axis=0)
|
119
|
+
chunk = ci//groups
|
120
|
+
chunko = co//groups
|
121
|
+
for g in range(groups):
|
122
|
+
for ii in range(hk):
|
123
|
+
for jj in range(wk):
|
124
|
+
x_windows = x[:, :, g*chunk:(g+1)*chunk, ii:ho *
|
125
|
+
stride[0]+ii:stride[0], jj:wo*stride[1]+jj:stride[1]]
|
126
|
+
out[:, g*chunko:(g+1)*chunko] += cp.sum(x_windows *
|
127
|
+
w[:, g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1], axis=2)
|
128
|
+
return out
|
129
|
+
|
130
|
+
|
131
|
+
def _conv_transpose2d(x, w, stride, pad, bias=None, groups=1):
|
132
|
+
""" Transposed convolution (equivalent pytorch.conv_transpose2d)
|
133
|
+
"""
|
134
|
+
b, co, ho, wo = x.shape
|
135
|
+
co, ci, hk, wk = w.shape
|
136
|
+
|
137
|
+
hi = (ho-1)*stride[0]+hk
|
138
|
+
wi = (wo-1)*stride[1]+wk
|
139
|
+
out = cp.zeros([b, ci, hi, wi], dtype='float32')
|
140
|
+
chunk = ci//groups
|
141
|
+
chunko = co//groups
|
142
|
+
for g in range(groups):
|
143
|
+
for ii in range(hk):
|
144
|
+
for jj in range(wk):
|
145
|
+
x_windows = x[:, g*chunko:(g+1)*chunko]
|
146
|
+
out[:, g*chunk:(g+1)*chunk, ii:ho*stride[0]+ii:stride[0], jj:wo*stride[1] +
|
147
|
+
jj:stride[1]] += x_windows * w[g*chunko:(g+1)*chunko, :, ii:ii+1, jj:jj+1]
|
148
|
+
if pad != 0:
|
149
|
+
out = out[:, :, pad[0]:out.shape[2]-pad[0], pad[1]:out.shape[3]-pad[1]]
|
150
|
+
return out
|
151
|
+
|
152
|
+
|
153
|
+
def afb1d(x, h0, h1='zero', dim=-1):
    """ 1D analysis filter bank (along one dimension only) of an image

    The input is symmetrically padded along the filtered axis, then filtered
    by the lowpass and highpass kernels with a stride of 2 along that axis.

    Parameters
    ----------
        x (array): 4D input with the last two dimensions the spatial input
        h0 (array): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        h1 (array): 4D input for the highpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w)
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).

    Returns
    -------
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension
    """
    n_channels = x.shape[1]
    axis = dim % 4  # negative dims are mapped to their positive counterpart
    filt_len = h0.size
    kernel_shape = [1, 1, 1, 1]
    kernel_shape[axis] = filt_len
    # One (lowpass, highpass) kernel pair per input channel: [lo, hi, lo, hi, ...]
    kernels = cp.concatenate(
        [h0.reshape(*kernel_shape), h1.reshape(*kernel_shape)] * n_channels, axis=0
    )

    # Pad so that the strided filtering yields pywt's "symmetric" coefficient count
    n_in = x.shape[axis]
    n_out = pywt.dwt_coeff_len(n_in, filt_len, mode='symmetric')
    total_pad = 2 * (n_out - 1) - n_in + filt_len
    before, after = total_pad // 2, (total_pad + 1) // 2

    if axis == 2:
        step = (2, 1)
        padded = _mypad(x, pad=(0, 0, before, after))
    else:
        step = (1, 2)
        padded = _mypad(x, pad=(before, after, 0, 0))
    return _conv2d(padded, kernels, stride=step, pad=0, groups=n_channels)
|
190
|
+
|
191
|
+
|
192
|
+
def sfb1d(lo, hi, g0, g1='zero', dim=-1):
    """ 1D synthesis filter bank of an image Array

    Upsamples the lowpass (*lo*) and highpass (*hi*) subbands by 2 along axis
    *dim* with transposed convolutions by *g0* / *g1* and sums the two results.
    """
    n_channels = lo.shape[1]
    axis = dim % 4
    filt_len = g0.size
    kernel_shape = [1, 1, 1, 1]
    kernel_shape[axis] = filt_len
    if axis == 2:
        step, crop = (2, 1), (filt_len - 2, 0)
    else:
        step, crop = (1, 2), (0, filt_len - 2)
    # Same kernel repeated for every channel (grouped transposed convolution)
    lo_kernels = cp.concatenate([g0.reshape(*kernel_shape)] * n_channels, axis=0)
    hi_kernels = cp.concatenate([g1.reshape(*kernel_shape)] * n_channels, axis=0)
    low_part = _conv_transpose2d(
        cp.asarray(lo), cp.asarray(lo_kernels), stride=step, pad=crop, groups=n_channels
    )
    high_part = _conv_transpose2d(
        cp.asarray(hi), cp.asarray(hi_kernels), stride=step, pad=crop, groups=n_channels
    )
    return low_part + high_part
|
210
|
+
|
211
|
+
|
212
|
+
class DWTForward:
    """ Performs a 2d DWT Forward decomposition of an image

    Args:
        wave (str): Which wavelet to use (any name accepted by pywt.Wavelet).
    """

    def __init__(self, wave='db1'):
        super().__init__()
        wavelet = pywt.Wavelet(wave)

        def _as_kernel(coeffs, kernel_shape):
            # Decomposition filters are time-reversed, since _conv2d does a cross-correlation
            return cp.array(coeffs).astype('float32')[::-1].reshape(kernel_shape)

        column_shape = (1, 1, -1, 1)
        row_shape = (1, 1, 1, -1)
        self.h0_col = _as_kernel(wavelet.dec_lo, column_shape)
        self.h1_col = _as_kernel(wavelet.dec_hi, column_shape)
        self.h0_row = _as_kernel(wavelet.dec_lo, row_shape)
        self.h1_row = _as_kernel(wavelet.dec_hi, row_shape)

    def apply(self, x):
        """ Forward pass of the DWT (one decomposition level).

        Args:
            x (array): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yl has shape :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`(N, C_{in}, 3, H_{in}', W_{in}')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}'` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        # Filter rows first, then columns
        rows_filtered = afb1d(x, self.h0_row, self.h1_row, dim=3)
        bands = afb1d(rows_filtered, self.h0_col, self.h1_col, dim=2)
        n_batch = bands.shape[0]
        height, width = bands.shape[-2], bands.shape[-1]
        # Group the 4 subbands produced for each input channel
        bands = bands.reshape(n_batch, -1, 4, height, width)  # pylint: disable=E1121
        lowpass = cp.ascontiguousarray(bands[:, :, 0])
        highpass = cp.ascontiguousarray(bands[:, :, 1:])
        return lowpass, highpass
|
262
|
+
|
263
|
+
|
264
|
+
class DWTInverse:
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str): Which wavelet to use (any name accepted by pywt.Wavelet).
    """

    def __init__(self, wave='db1'):
        super().__init__()
        wavelet = pywt.Wavelet(wave)
        column_shape = (1, 1, -1, 1)
        row_shape = (1, 1, 1, -1)
        # Reconstruction filters, used as-is (no time reversal)
        self.g0_col = cp.array(wavelet.rec_lo).astype('float32').reshape(column_shape)
        self.g1_col = cp.array(wavelet.rec_hi).astype('float32').reshape(column_shape)
        self.g0_row = cp.array(wavelet.rec_lo).astype('float32').reshape(row_shape)
        self.g1_row = cp.array(wavelet.rec_hi).astype('float32').reshape(row_shape)

    def apply(self, coeffs):
        """ Inverse pass of the DWT (one reconstruction level).

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, in the
              format returned by DWTForward.apply: yl has shape
              :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
              :math:`(N, C_{in}, 3, H_{in}', W_{in}')` (LH, HL, HH bands).

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
        """
        yl, yh = coeffs
        lh, hl, hh = (yh[:, :, band] for band in range(3))
        # Undo the column filtering first, then the row filtering
        low_cols = sfb1d(yl, lh, self.g0_col, self.g1_col, dim=2)
        high_cols = sfb1d(hl, hh, self.g0_col, self.g1_col, dim=2)
        return sfb1d(low_cols, high_cols, self.g0_row, self.g1_row, dim=3)
|
307
|
+
|
308
|
+
|
309
|
+
def remove_stripe_fw(data, sigma, wname, level):
    """Remove stripes with wavelet filtering (Fourier-wavelet method).

    Parameters
    ----------
    data : array
        3D array of shape (nproj, nz, ni).
    sigma : float
        Width of the Gaussian damping applied in Fourier space.
    wname : str
        Wavelet name, any wavelet accepted by PyWavelets.
    level : int
        Number of decomposition levels.

    Returns
    -------
    array of shape (nproj, nz, ni) and dtype data.dtype.
    """
    nproj, nz, ni = data.shape
    # Pad the projection axis by nproj // 8, half on each side
    nproj_pad = nproj + nproj // 8
    start = (nproj_pad - nproj) // 2
    stop = (nproj_pad + nproj) // 2

    forward = DWTForward(wave=wname)
    inverse = DWTInverse(wave=wname)

    # Wavelet decomposition, damping one detail band at each level
    details = []
    approx = cp.zeros([nz, 1, nproj_pad, ni], dtype='float32')
    approx[:, 0, start:stop] = data.astype('float32').swapaxes(0, 1)
    for _ in range(level):
        approx, band = forward.apply(approx)
        details.append(band)
        spectrum = cp.fft.fft(band[:, 0, 1], axis=1)
        n_freq = spectrum.shape[1]
        # Damping of ring artifact information along the projection-axis frequencies
        freq = cp.fft.ifftshift((cp.arange(-n_freq, n_freq, 2) + 1) / 2)
        damp = -cp.expm1(-freq**2 / (2 * sigma**2))
        spectrum *= damp[:, cp.newaxis]
        band[:, 0, 1] = cp.fft.ifft(spectrum, n_freq, axis=1).real

    # Wavelet reconstruction, coarsest level first
    for band in reversed(details):
        height, width = band[0, 0, 1].shape
        approx = approx[:, :, :height, :width]
        approx = inverse.apply((approx, band))

    result = approx[:, 0, start:stop, :ni].astype(data.dtype)  # modified
    return result.swapaxes(0, 1)
|
351
|
+
|
352
|
+
######## Titarenko ring removal ############################################################################################################################################################################
|
353
|
+
def remove_stripe_ti(data, beta, mask_size):
    """Remove stripes with the method by V. Titarenko. Modifies `data` in place.

    Parameters
    ----------
    data : array
        3D array (n_angles, n_z, width); the mean over axis 0 is corrected along the last axis.
    beta : float
        Filter parameter used to build the correction kernel.
    mask_size : float
        Size of the centered region where the correction is applied, as a multiple
        of the width (presumably in [0, 1]; e.g. 0.5 corrects the central half).

    Returns
    -------
    data, corrected in place.
    """
    width = data.shape[-1]
    gamma = beta * ((1 - beta) / (1 + beta)) ** cp.abs(cp.fft.fftfreq(width) * width)
    gamma[0] -= 1
    v = cp.mean(data, axis=0)
    v = v - v[:, 0:1]
    # n must be passed explicitly: otherwise irfft returns width-1 samples for odd widths
    v = cp.fft.irfft(cp.fft.rfft(v) * cp.fft.rfft(gamma), n=width)
    mask = cp.zeros(v.shape, dtype=v.dtype)
    # int(): mask_size may be a float fraction, and slice bounds must be integers
    half = int(mask_size * width) // 2
    center = width // 2
    mask[:, max(center - half, 0):center + half] = 1
    data[:] += v * mask
    return data
|
366
|
+
|
367
|
+
|
368
|
+
######## Optimized version for Vo-all ring removal in tomopy################################################################################################################################################################
|
369
|
+
def _rs_sort(sinogram, size, matindex, dim):
|
370
|
+
"""
|
371
|
+
Remove stripes using the sorting technique.
|
372
|
+
"""
|
373
|
+
sinogram = cp.transpose(sinogram)
|
374
|
+
matcomb = cp.asarray(cp.dstack((matindex, sinogram)))
|
375
|
+
|
376
|
+
# matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcomb])
|
377
|
+
ids = cp.argsort(matcomb[:,:,1],axis=1)
|
378
|
+
matsort = matcomb.copy()
|
379
|
+
matsort[:,:,0] = cp.take_along_axis(matsort[:,:,0],ids,axis=1)
|
380
|
+
matsort[:,:,1] = cp.take_along_axis(matsort[:,:,1],ids,axis=1)
|
381
|
+
if dim == 1:
|
382
|
+
matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, 1))
|
383
|
+
else:
|
384
|
+
matsort[:, :, 1] = median_filter(matsort[:, :, 1], (size, size))
|
385
|
+
|
386
|
+
# matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
|
387
|
+
|
388
|
+
ids = cp.argsort(matsort[:,:,0],axis=1)
|
389
|
+
matsortback = matsort.copy()
|
390
|
+
matsortback[:,:,0] = cp.take_along_axis(matsortback[:,:,0],ids,axis=1)
|
391
|
+
matsortback[:,:,1] = cp.take_along_axis(matsortback[:,:,1],ids,axis=1)
|
392
|
+
|
393
|
+
sino_corrected = matsortback[:, :, 1]
|
394
|
+
return cp.transpose(sino_corrected)
|
395
|
+
|
396
|
+
def _mpolyfit(x,y):
|
397
|
+
n= len(x)
|
398
|
+
x_mean = cp.mean(x)
|
399
|
+
y_mean = cp.mean(y)
|
400
|
+
|
401
|
+
Sxy = cp.sum(x*y) - n*x_mean*y_mean
|
402
|
+
Sxx = cp.sum(x*x) - n*x_mean*x_mean
|
403
|
+
|
404
|
+
slope = Sxy / Sxx
|
405
|
+
intercept = y_mean - slope*x_mean
|
406
|
+
return slope,intercept
|
407
|
+
|
408
|
+
def _detect_stripe(listdata, snr):
    """
    Algorithm 4 in :cite:`Vo:18`. Used to locate stripes.

    A line is fitted to the middle half of the sorted values. Values further than
    `snr * noiselevel / 2` from the fitted ends are flagged, where noiselevel is the
    span of that fit.

    Parameters
    ----------
    listdata : 1D array
    snr : float
        Ratio used to locate stripes. Greater is less sensitive.

    Returns
    -------
    listmask : array like listdata, 1.0 where a stripe is detected, 0.0 elsewhere.
    """
    numdata = len(listdata)
    listsorted = cp.sort(listdata)[::-1]
    xlist = cp.arange(0, numdata, 1.0)
    # Plain Python int: cp.int16 silently overflows once 0.25 * numdata exceeds 32767
    ndrop = int(0.25 * numdata)
    _slope, _intercept = _mpolyfit(xlist[ndrop:-ndrop - 1], listsorted[ndrop:-ndrop - 1])

    numt1 = _intercept + _slope * xlist[-1]
    noiselevel = cp.clip(cp.abs(numt1 - _intercept), 1e-6, None)
    val1 = cp.abs(listsorted[0] - _intercept) / noiselevel
    val2 = cp.abs(listsorted[-1] - numt1) / noiselevel
    listmask = cp.zeros_like(listdata)
    if val1 >= snr:
        upper_thresh = _intercept + noiselevel * snr * 0.5
        listmask[listdata > upper_thresh] = 1.0
    if val2 >= snr:
        lower_thresh = numt1 - noiselevel * snr * 0.5
        listmask[listdata <= lower_thresh] = 1.0
    return listmask
|
433
|
+
|
434
|
+
def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
    """
    Remove large stripes.

    Parameters
    ----------
    sinogram : array, shape (n_angles, n_cols)
        When norm is False, detected columns are overwritten in this array directly.
    snr : float
        Ratio used to locate stripes. Greater is less sensitive.
    size : int
        Median filter window size.
    matindex : array, shape (n_cols, n_angles)
        Index array as built by _create_matindex(n_cols, n_angles).
    drop_ratio : float
        Fraction of extreme values (clipped to [0, 0.8]) ignored when estimating column ratios.
    norm : bool
        If True, the sinogram is first normalized by the estimated column ratios.
    """
    drop_ratio = max(min(drop_ratio, 0.8), 0)
    nrow = sinogram.shape[0]
    ndrop = int(0.5 * drop_ratio * nrow)
    sinosort = cp.sort(sinogram, axis=0)
    sinosmooth = median_filter(sinosort, (1, size))
    list1 = cp.mean(sinosort[ndrop:nrow - ndrop], axis=0)
    list2 = cp.mean(sinosmooth[ndrop:nrow - ndrop], axis=0)
    # Guarded division, as in algotom (np.divide(..., out=ones, where=list2 != 0)):
    # a zero column would otherwise produce inf/nan and corrupt the whole detection.
    nonzero = list2 != 0
    listfact = cp.where(
        nonzero, list1 / cp.where(nonzero, list2, cp.ones_like(list2)), cp.ones_like(list1)
    )

    # Locate stripes
    listmask = _detect_stripe(listfact, snr)
    listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
    matfact = cp.tile(listfact, (nrow, 1))
    # Normalize
    if norm is True:
        sinogram = sinogram / matfact
    matcombine = cp.asarray(cp.dstack((matindex, cp.transpose(sinogram))))

    # Sort each row by intensity (only the index layer is needed: the intensity
    # layer is replaced by the smoothed sorted sinogram right after)
    ids = cp.argsort(matcombine[:, :, 1], axis=1)
    matsort = matcombine.copy()
    matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
    matsort[:, :, 1] = cp.transpose(sinosmooth)

    # Sort back to the original angle order
    ids = cp.argsort(matsort[:, :, 0], axis=1)
    matsortback = matsort.copy()
    matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
    matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)

    sino_corrected = cp.transpose(matsortback[:, :, 1])
    listxmiss = cp.where(listmask > 0.0)[0]
    sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
    return sinogram
|
479
|
+
|
480
|
+
def _rs_dead(sinogram, snr, size, matindex, norm=True):
    """
    Remove unresponsive and fluctuating stripes.

    Columns detected as stripes are replaced by linear interpolation between the
    nearest non-detected columns. Then, if norm is True, residual stripes are
    removed with _rs_large.

    Parameters
    ----------
    sinogram : array, shape (n_angles, n_cols)
        Not modified (a copy is made).
    snr : float
        Ratio used to locate stripes. Greater is less sensitive.
    size : int
        Median filter window size.
    matindex : array, shape (n_cols, n_angles)
        Index array as built by _create_matindex(n_cols, n_angles).
    norm : bool
        Whether to apply _rs_large afterwards.
    """
    sinogram = cp.copy(sinogram)  # Make it mutable
    sinosmooth = uniform_filter1d(sinogram, 10, axis=0)

    listdiff = cp.sum(cp.abs(sinogram - sinosmooth), axis=0)
    listdiffbck = median_filter(listdiff, size)
    # Guarded division, as in algotom (np.divide(..., out=ones, where=listdiffbck != 0)):
    # flat regions give a zero background, which would otherwise produce inf/nan.
    nonzero = listdiffbck != 0
    listfact = cp.where(
        nonzero,
        listdiff / cp.where(nonzero, listdiffbck, cp.ones_like(listdiffbck)),
        cp.ones_like(listdiff),
    )

    listmask = _detect_stripe(listfact, snr)
    listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
    # Keep the two first/last columns, so that interpolation always has neighbours on both sides
    listmask[0:2] = 0.0
    listmask[-2:] = 0.0
    listx = cp.where(listmask < 1.0)[0]
    matz = sinogram[:, listx]
    listxmiss = cp.where(listmask > 0.0)[0]

    if len(listxmiss) > 0:
        # Linear interpolation along the column axis (replaces scipy interp2d of the original code)
        ids = cp.searchsorted(listx, listxmiss)
        x0 = listx[ids - 1]
        x1 = listx[ids]
        z0 = matz[:, ids - 1]
        z1 = matz[:, ids]
        sinogram[:, listxmiss] = z0 + (listxmiss - x0) * (z1 - z0) / (x1 - x0)

    # Remove residual stripes
    if norm is True:
        sinogram = _rs_large(sinogram, snr, size, matindex)
    return sinogram
|
515
|
+
|
516
|
+
|
517
|
+
def _create_matindex(nrow, ncol):
|
518
|
+
"""
|
519
|
+
Create a 2D array of indexes used for the sorting technique.
|
520
|
+
"""
|
521
|
+
listindex = cp.arange(0.0, ncol, 1.0)
|
522
|
+
matindex = cp.tile(listindex, (nrow, 1))
|
523
|
+
return matindex
|
524
|
+
|
525
|
+
def remove_all_stripe(tomo, snr=3, la_size=61, sm_size=21, dim=1):
    """
    Remove all types of stripe artifacts from sinograms using Nghia Vo's
    approach :cite:`Vo:18` (combination of algorithms 3, 4, 5 and 6).

    Each sinogram ``tomo[:, i, :]`` goes through `_rs_dead` (dead/fluctuating
    stripes, then large stripes), then `_rs_sort` (small-to-medium stripes).
    The result is written back into `tomo`.

    Parameters
    ----------
    tomo : ndarray
        3D tomographic data; sinogram number i is ``tomo[:, i, :]``.
        It is modified in place.
    snr : float
        Ratio used to locate large stripes.
        Greater is less sensitive.
    la_size : int
        Window size of the median filter to remove large stripes.
    sm_size : int
        Window size of the median filter to remove small-to-medium stripes.
    dim : {1, 2}, optional
        Dimension of the window (forwarded to `_rs_sort`).

    Returns
    -------
    ndarray
        Corrected 3D tomographic data (the same object as `tomo`).
    """
    # The index matrix depends only on the data shape: build it once
    index_matrix = _create_matindex(tomo.shape[2], tomo.shape[0])
    for slice_idx in range(tomo.shape[1]):
        cleaned = _rs_dead(tomo[:, slice_idx, :], snr, la_size, index_matrix)
        tomo[:, slice_idx, :] = _rs_sort(cleaned, sm_size, index_matrix, dim)
    return tomo
|
560
|
+
|
561
|
+
|
562
|
+
from ..cuda.utils import pycuda_to_cupy
|
563
|
+
def remove_all_stripe_pycuda(radios, device_id=0, **kwargs):
    """
    Nabu interface to "remove_all_stripe", operating in-place on a pycuda array.

    Parameters
    ----------
    radios: pycuda.GPUArray
        Stack of radios in the shape (n_angles, n_y, n_x)
        so that sinogram number i is radios[:, i, :]
    device_id: int, optional
        Device made current for cupy. It is only used on the first call;
        afterwards a flag stored on `remove_all_stripe` skips the selection.

    Other Parameters
    ----------------
    See parameters of 'remove_all_stripe'.

    Returns
    -------
    radios: pycuda.GPUArray
        The input array itself, corrected in place.
    """
    cupy_ready = getattr(remove_all_stripe, "_cupy_init", False)
    if cupy_ready is False:
        # NOTE(review): device selection happens once per process; a different
        # device_id in a later call is ignored. Confirm this is intended.
        from cupy import cuda

        cuda.Device(device_id).use()
        setattr(remove_all_stripe, "_cupy_init", True)

    # No memory copy: the pycuda buffer is exposed as a cupy array, so
    # remove_all_stripe() writes its result directly into `radios`.
    remove_all_stripe(pycuda_to_cupy(radios), **kwargs)
    return radios
|
586
|
+
|
nabu/utils.py
CHANGED
@@ -1,13 +1,11 @@
|
|
1
1
|
from functools import partial
|
2
2
|
import os
|
3
3
|
import sys
|
4
|
-
from
|
5
|
-
import
|
4
|
+
from functools import partial, lru_cache
|
5
|
+
from itertools import product
|
6
6
|
import warnings
|
7
7
|
from time import time
|
8
8
|
import posixpath
|
9
|
-
from itertools import product
|
10
|
-
from functools import lru_cache
|
11
9
|
import numpy as np
|
12
10
|
|
13
11
|
|
@@ -424,6 +422,19 @@ def copy_dict_items(dict_, keys):
|
|
424
422
|
return res
|
425
423
|
|
426
424
|
|
425
|
+
def remove_first_dict_items(dict_, n_items, sort_func=None, inplace=True):
    """
    Remove the first items of a dictionary, "first" meaning smallest keys.

    The keys are ordered with ``sorted(keys, key=sort_func)``, so they have
    to be sortable (or `sort_func` must map them to sortable values).

    :param dict dict_: dictionary to process
    :param int n_items: number of leading keys (in sorted order) to drop
    :param sort_func: optional key function passed to `sorted`
    :param bool inplace: if True, the keys are popped from `dict_` and `dict_`
        itself is returned. Otherwise, nothing is popped from `dict_` and the
        result of `copy_dict_items` on the remaining keys is returned.
    """
    ordered_keys = sorted(dict_.keys(), key=sort_func)
    if not inplace:
        return copy_dict_items(dict_, ordered_keys[n_items:])
    for doomed_key in ordered_keys[:n_items]:
        dict_.pop(doomed_key)
    return dict_
|
436
|
+
|
437
|
+
|
427
438
|
def recursive_copy_dict(dict_):
|
428
439
|
"""
|
429
440
|
Perform a shallow copy of a dictionary of dictionaries.
|
@@ -537,7 +548,7 @@ def filter_str_def(elmt):
|
|
537
548
|
return elmt
|
538
549
|
|
539
550
|
|
540
|
-
def convert_str_to_tuple(input_str: str, none_if_empty: bool = False)
|
551
|
+
def convert_str_to_tuple(input_str: str, none_if_empty: bool = False):
|
541
552
|
"""
|
542
553
|
:param str input_str: string to convert
|
543
554
|
:param bool none_if_empty: if true and the conversion is an empty tuple
|
@@ -572,7 +583,7 @@ class Progress:
|
|
572
583
|
self._name = name
|
573
584
|
self.reset()
|
574
585
|
|
575
|
-
def reset(self, max_
|
586
|
+
def reset(self, max_=None):
|
576
587
|
"""
|
577
588
|
reset the advancement to n and max advancement to max_
|
578
589
|
:param int max_:
|
@@ -630,6 +641,19 @@ def concatenate_dict(dict_1, dict_2) -> dict:
|
|
630
641
|
return res
|
631
642
|
|
632
643
|
|
644
|
+
class BaseClassError:
    """
    Placeholder class that cannot be instantiated.

    Creating an instance, with any arguments, raises ``ValueError("Base class")``.
    """

    def __init__(self, *_args, **_kwargs):
        raise ValueError("Base class")
|
647
|
+
|
648
|
+
|
649
|
+
def MissingComponentError(msg):
    """
    Build a stand-in class for a component that is not available.

    :param str msg: message of the RuntimeError raised on instantiation
    :return: a class named ``MissingComponentCls``. Its constructor accepts
        any arguments and always raises ``RuntimeError(msg)``.
    """

    class MissingComponentCls:
        # Any attempt to instantiate this class reports the missing component
        def __init__(self, *_args, **_kwargs):
            raise RuntimeError(msg)

    return MissingComponentCls
|
655
|
+
|
656
|
+
|
633
657
|
# ------------------------------------------------------------------------------
|
634
658
|
# ------------------------ Image (move elsewhere ?) ----------------------------
|
635
659
|
# ------------------------------------------------------------------------------
|
@@ -649,6 +673,8 @@ def generate_coords(img_shp, center=None):
|
|
649
673
|
|
650
674
|
def clip_circle(img, center=None, radius=None):
|
651
675
|
R, C = generate_coords(img.shape, center)
|
676
|
+
if radius is None:
|
677
|
+
radius = R.shape[-1] // 2
|
652
678
|
M = R**2 + C**2
|
653
679
|
res = np.zeros_like(img)
|
654
680
|
res[M < radius**2] = img[M < radius**2]
|