nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +36 -5
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.14.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/preproc/tests/test_pcaflats.py
ADDED
@@ -0,0 +1,154 @@
+import os
+import numpy as np
+import pytest
+
+import h5py
+from nabu.testutils import utilstest
+from nabu.preproc.flatfield import (
+    PCAFlatsDecomposer,
+    PCAFlatsNormalizer,
+)
+
+
+@pytest.fixture(scope="class")
+def bootstrap_pcaflats(request):
+    cls = request.cls
+    # TODO: these tolerances for having the tests passed should be tighter.
+    # Discrepancies between id11 code and nabu code are still mysterious.
+    cls.mean_abs_tol = 1e-1
+    cls.comps_abs_tol = 1e-2
+    cls.projs, cls.flats, cls.darks = get_pcaflats_data("test_pcaflats.npz")
+    cls.raw_projs = cls.projs.copy()  # Needed because flat correction is done inplace.
+    ref_data = get_pcaflats_refdata("ref_pcaflats.npz")
+    cls.mean = ref_data["mean"]
+    cls.components_3 = ref_data["components_3"]
+    cls.components_15 = ref_data["components_15"]
+    cls.dark = ref_data["dark"]
+    cls.normalized_projs_3 = ref_data["normalized_projs_3"]
+    cls.normalized_projs_15 = ref_data["normalized_projs_15"]
+    cls.normalized_projs_custom_mask = ref_data["normalized_projs_custom_mask"]
+    cls.test_normalize_projs_custom_prop = ref_data["normalized_projs_custom_prop"]
+
+    cls.h5_filename_3 = get_h5_pcaflats("pcaflat_3.h5")
+    cls.h5_filename_15 = get_h5_pcaflats("pcaflat_15.h5")
+
+
+def get_pcaflats_data(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+    data = np.load(dataset_downloaded_path)
+    projs = data["projs"].astype(np.float32)
+    flats = data["flats"].astype(np.float32)
+    darks = data["darks"].astype(np.float32)
+
+    return projs, flats, darks
+
+
+def get_h5_pcaflats(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+
+    return dataset_downloaded_path
+
+
+def get_pcaflats_refdata(*dataset_path):
+    """
+    Get a dataset file from silx.org/pub/nabu/data
+    dataset_args is a list describing a nested folder structures, ex.
+    ["path", "to", "my", "dataset.h5"]
+    """
+    dataset_relpath = os.path.join(*dataset_path)
+    dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+    data = np.load(dataset_downloaded_path)
+
+    return data
+
+
+def get_decomposition(filename):
+    with h5py.File(filename, "r") as f:
+        # Load the dataset
+        p_comps = f["entry0000/p_components"][()]
+        p_mean = f["entry0000/p_mean"][()]
+        dark = f["entry0000/dark"][()]
+    return p_comps, p_mean, dark
+
+
+@pytest.mark.usefixtures("bootstrap_pcaflats")
+class TestPCAFlatsDecomposer:
+    def test_decompose_flats(self):
+        # Build 3-sigma basis
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+        message = f"Found a discrepency between computed mean flat and reference."
+        assert np.allclose(self.mean, pca.mean, atol=self.mean_abs_tol), message
+        message = f"Found a discrepency between computed components and reference ones if nsigma=3."
+        assert np.allclose(self.components_3, np.array(pca.components), atol=self.comps_abs_tol), message
+
+        # Build 1.5-sigma basis
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=1.5)
+        message = f"Found a discrepency between computed components and reference ones, if nsigma=1.5."
+        assert np.allclose(self.components_15, np.array(pca.components), atol=self.comps_abs_tol), message
+
+    def test_save_load_decomposition(self):
+        pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+        tmp_path = os.path.join(os.path.dirname(self.h5_filename_3), "PCA_Flats.h5")
+        pca.save_decomposition(path=tmp_path)
+        p_comps, p_mean, dark = get_decomposition(tmp_path)
+        message = f"Found a discrepency between saved and loaded mean flat."
+        assert np.allclose(self.mean, p_mean, atol=self.mean_abs_tol), message
+        message = f"Found a discrepency between saved and loaded components if nsigma=3."
+        assert np.allclose(self.components_3, p_comps, atol=self.comps_abs_tol), message
+        message = f"Found a discrepency between saved and loaded dark."
+        assert np.allclose(self.dark, dark, atol=self.comps_abs_tol), message
+        # Clean up
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+
+
+@pytest.mark.usefixtures("bootstrap_pcaflats")
+class TestPCAFlatsNormalizer:
+    def test_load_pcaflats(self):
+        """Tests that the structure of the output PCAFlat h5 file is correct."""
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        # Check the shape of the loaded data
+        assert p_comps.shape[1:] == p_mean.shape
+        assert p_comps.shape[1:] == dark.shape
+
+    def test_normalize_projs(self):
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs)
+        assert np.allclose(projs, self.normalized_projs_3, atol=1e-2)
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_15)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs)
+        assert np.allclose(projs, self.normalized_projs_15, atol=1e-2)
+
+    def test_use_custom_mask(self):
+        mask = np.zeros(self.mean.shape, dtype=bool)
+        mask[:, :10] = True
+        mask[:, -10:] = True
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs, mask=mask)
+        assert np.allclose(projs, self.normalized_projs_custom_mask, atol=1e-2)
+
+    def test_change_mask_prop(self):
+        p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+        pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+        projs = self.raw_projs.copy()
+        pca.normalize_radios(projs, prop=0.05)
+        assert np.allclose(projs, self.test_normalize_projs_custom_prop, atol=1e-2)
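Note (not part of the diff): the test file above is the clearest picture of how the new PCA-flats API in `nabu.preproc.flatfield` is exercised. A minimal sketch based only on those tests; array shapes and file names below are illustrative assumptions, not values from the package:

```python
# Hedged usage sketch derived from test_pcaflats.py above; shapes and paths are made up.
import h5py
import numpy as np
from nabu.preproc.flatfield import PCAFlatsDecomposer, PCAFlatsNormalizer

flats = np.random.rand(20, 64, 64).astype(np.float32)   # stack of flat frames
darks = np.random.rand(2, 64, 64).astype(np.float32)    # stack of dark frames
projs = np.random.rand(50, 64, 64).astype(np.float32)   # projections, corrected in-place below

# Build the PCA basis from flats/darks and save it (entry0000/{p_components, p_mean, dark})
decomposer = PCAFlatsDecomposer(flats, darks, nsigma=3)
decomposer.save_decomposition(path="PCA_Flats.h5")

# Load the basis back and normalize projections in-place, as the tests do
with h5py.File("PCA_Flats.h5", "r") as f:
    p_comps = f["entry0000/p_components"][()]
    p_mean = f["entry0000/p_mean"][()]
    dark = f["entry0000/dark"][()]
normalizer = PCAFlatsNormalizer(p_comps, dark, p_mean)
normalizer.normalize_radios(projs)  # also accepts mask=... and prop=... as in the tests
```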
nabu/preproc/tests/test_vshift.py
CHANGED
@@ -70,4 +70,7 @@ class TestVerticalShift:
         Shifter_neg_cuda = CudaVerticalShift(d_radios.shape, -self.shifts)
         Shifter_neg_cuda.apply_vertical_shifts(d_radios2, self.indexes)
         err_max = np.max(np.abs(d_radios2.get() - radios2))
-
+        #
+        # FIXME tolerance was downgraded from 1e-6 to 8e-6 when switching to numpy 2
+        #
+        assert err_max < 8e-6, "Something wrong for negative translations: max error = %.2e" % err_max
nabu/processing/azim.py
CHANGED
@@ -96,8 +96,11 @@ def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, ca
     Accumulator = np.zeros((2, nBins))
 
     # Define the bounding box
-
-
+    height, width = ip.shape
+    xmin = max(int(X0 - mR), 0)
+    xmax = min(int(X0 + mR), width)
+    ymin = max(int(Y0 - mR), 0)
+    ymax = min(int(Y0 + mR), height)
 
     # Create grid of coordinates
     x = np.arange(xmin, xmax)
@@ -112,10 +115,11 @@ def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, ca
     bins = np.clip(bins - 1, 0, nBins - 1)  # Adjust bins to be in range [0, nBins-1]
 
     # Accumulate values
+    sub_image = ip[xmin:xmax, ymin:ymax]  # prevent issue on non-square images
     for b in range(nBins):
         mask = bins == b
         Accumulator[0, b] = np.sum(mask)
-        Accumulator[1, b] = np.sum(
+        Accumulator[1, b] = np.sum(sub_image[mask])
 
     # Normalize integrated intensity
     Accumulator[1] /= Accumulator[0]
@@ -123,11 +127,11 @@ def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, ca
     if use_calibration and cal is not None:
         # Apply calibration if units are provided
         radii = cal.pixel_width * mR * (np.arange(1, nBins + 1) / nBins)
-        units = cal.units
+        # units = cal.units
     else:
         # Use pixel units
         radii = mR * (np.arange(1, nBins + 1) / nBins)
-        units = "pixels"
+        # units = "pixels"
 
     if return_radii:
         return radii, Accumulator[1]
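The azim.py change above clamps the integration window to the image bounds, so the radial accumulation indexes `ip[xmin:xmax, ymin:ymax]` safely on non-square images. A standalone numpy illustration of that clamping (the array size, center and radius below are arbitrary, not taken from nabu):

```python
# Standalone illustration of the bounding-box clamping added in do_radial_distribution.
import numpy as np

ip = np.zeros((100, 80), dtype=np.float32)  # hypothetical non-square image
X0, Y0, mR = 75, 10, 20                     # window partly falls outside the array

height, width = ip.shape
xmin, xmax = max(int(X0 - mR), 0), min(int(X0 + mR), width)
ymin, ymax = max(int(Y0 - mR), 0), min(int(Y0 + mR), height)

sub_image = ip[xmin:xmax, ymin:ymax]
print(sub_image.shape)  # (25, 30): both slices stay inside the array bounds
```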
nabu/processing/convolution_cuda.py
CHANGED
@@ -159,7 +159,7 @@ class Convolution:
             self.d_kernel = self.cuda.to_device("d_kernel", self.kernel)
         else:
             if not (isinstance(self.kernel, self.cuda.array_class)):
-                raise
+                raise TypeError("kernel must be either numpy array or pycuda array")
             self.d_kernel = self.kernel
         self._old_input_ref = None
         self._old_output_ref = None
@@ -185,7 +185,7 @@ class Convolution:
         self._c_conv_mode = mp[self.mode]
 
     def _init_kernels(self):
-        if self.kernel_ndim > 1:
+        if self.kernel_ndim > 1:  # noqa: SIM102
             if np.abs(np.diff(self.kernel.shape)).max() > 0:
                 raise NotImplementedError("Non-separable convolution with non-square kernels is not implemented yet")
         # Compile source module
@@ -290,7 +290,7 @@ class Convolution:
         return ndim
 
     def _check_array(self, arr):
-        if not (isinstance(arr, self.cuda.array_class) or isinstance(arr, np.ndarray)):
+        if not (isinstance(arr, self.cuda.array_class) or isinstance(arr, np.ndarray)):  # noqa: SIM101
             raise TypeError("Expected either pycuda.gpuarray or numpy.ndarray")
         if arr.dtype != np.float32:
             raise TypeError("Data must be float32")
@@ -305,7 +305,7 @@ class Convolution:
                 self._old_input_ref = self.data_in
             self.data_in = array
         data_in_ref = self.data_in
-        if output is not None:
+        if output is not None:  # noqa: SIM102
             if not (isinstance(output, np.ndarray)):
                 self._old_output_ref = self.data_out
                 self.data_out = output
@@ -324,11 +324,13 @@ class Convolution:
         cuda_kernel = self.cuda_kernels[axis]
         cuda_kernel_args = self._configure_kernel_args(self.kernel_args, input_ref, output_ref)
         ev = cuda_kernel.prepared_call(*cuda_kernel_args)
+        return ev
 
     def _nd_convolution(self):
         assert len(self.use_case_kernels) == 1
         cuda_kernel = self._module.get_function(self.use_case_kernels[0])
         ev = cuda_kernel.prepared_call(*self.kernel_args)
+        return ev
 
     def _recover_arrays_references(self):
         if self._old_input_ref is not None:
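Aside on the first hunk above: a bare `raise` is only valid while an exception is being handled; outside an `except` block it fails with `RuntimeError: No active exception to re-raise`, which is why the kernel check now raises an explicit `TypeError`. A minimal sketch of the corrected pattern (the function name below is illustrative, not the package's API):

```python
import numpy as np

def check_kernel(kernel, array_class=np.ndarray):
    # Illustrative stand-in for the kernel type check in Convolution
    if not isinstance(kernel, array_class):
        raise TypeError("kernel must be either numpy array or pycuda array")

check_kernel(np.ones(3))         # OK
# check_kernel([1.0, 2.0, 3.0])  # would raise TypeError with a clear message
```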
nabu/processing/fft_base.py
CHANGED
@@ -35,7 +35,7 @@ class _BaseFFT:
            the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N).
          * "none": no normalizatio is done : IFFT(FFT(data)) = data*N
 
-        Other
+        Other Parameters
         -----------------
         backend_options: dict, optional
             Parameters to pass to CudaProcessing or OpenCLProcessing class.
@@ -93,6 +93,10 @@ class _BaseFFT:
         pass
 
 
+def raise_base_class_error(slf, *args, **kwargs):
+    raise ValueError
+
+
 class _BaseVKFFT(_BaseFFT):
     """
     FFT using VKFFT backend
@@ -101,7 +105,7 @@ class _BaseVKFFT(_BaseFFT):
     implem = "vkfft"
     backend = "none"
     ProcessingCls = BaseClassError
-
+    get_fft_obj = raise_base_class_error
 
     def _configure_batched_transform(self):
         if self.axes is not None and len(self.shape) == len(self.axes):
@@ -128,7 +132,7 @@ class _BaseVKFFT(_BaseFFT):
             self._vkfft_ndim = None
 
     def _compute_fft_plans(self):
-        self._vkfft_plan = self.
+        self._vkfft_plan = self.get_fft_obj(
             self.shape,
             self.dtype,
             ndim=self._vkfft_ndim,
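The fft_base.py hunks above introduce an overridable factory attribute: the base VKFFT class binds a placeholder `get_fft_obj` that raises, `_compute_fft_plans` calls it to build the plan, and each backend subclass rebinds the attribute to its own plan factory. A condensed sketch of the pattern (class and method names here are illustrative, not nabu's):

```python
# Condensed sketch of the "overridable factory attribute" pattern used above.
def raise_base_class_error(slf, *args, **kwargs):
    # Placeholder bound on the base class; only reached if a subclass forgot to override it.
    raise ValueError

class BaseFFTSketch:
    get_fft_obj = raise_base_class_error  # subclasses rebind this to a real plan factory

    def compute_plan(self, shape, dtype):
        return self.get_fft_obj(shape, dtype)

class CudaFFTSketch(BaseFFTSketch):
    def get_fft_obj(self, shape, dtype):
        # A real backend would return a (possibly cached) pyvkfft plan here.
        return ("fake-plan", shape, dtype)

print(CudaFFTSketch().compute_plan((16,), "f"))  # ('fake-plan', (16,), 'f')
# BaseFFTSketch().compute_plan((16,), "f")       # would raise ValueError
```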
nabu/processing/fft_cuda.py
CHANGED
@@ -1,148 +1,33 @@
 import os
 import warnings
+from functools import lru_cache
 from multiprocessing import get_context
 from multiprocessing.pool import Pool
-import
-from
-from .fft_base import _BaseFFT, _BaseVKFFT
+from ..utils import BaseClassError, check_supported, no_decorator
+from .fft_base import _BaseVKFFT
 
 try:
-    from pyvkfft.cuda import VkFFTApp as
+    from pyvkfft.cuda import VkFFTApp as CudaVkFFTApp
 
     __has_vkfft__ = True
 except (ImportError, OSError):
     __has_vkfft__ = False
-
+    CudaVkFFTApp = BaseClassError
 from ..cuda.processing import CudaProcessing
 
-
-cu_fft = None
-cu_ifft = None
-__has_skcuda__ = None
+n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
 
 
-
-# This needs to be done here, because scikit-cuda creates a Cuda context at import,
-# which can mess things up in some cases.
-# Ugly solution to an ugly problem.
-global __has_skcuda__, Plan, cu_fft, cu_ifft
-try:
-    from skcuda.fft import Plan
-    from skcuda.fft import fft as cu_fft
-    from skcuda.fft import ifft as cu_ifft
+maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
 
-    __has_skcuda__ = True
-except ImportError:
-    __has_skcuda__ = False
 
+@maybe_cached
+def _get_vkfft_cuda(*args, **kwargs):
+    return CudaVkFFTApp(*args, **kwargs)
 
-class SKCUFFT(_BaseFFT):
-    implem = "skcuda"
-    backend = "cuda"
-    ProcessingCls = CudaProcessing
 
-
-
-        init_skcuda()
-        if not (__has_skcuda__):
-            raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
-
-        self.cufft_batch_size = 1
-        self.cufft_shape = self.shape
-        self._cufft_plan_kwargs = {}
-        if (self.axes is not None) and (len(self.axes) < len(self.shape)):
-            # In the easiest case, the transform is computed along the fastest dimensions:
-            #   - 1D transforms of lines of 2D data
-            #   - 2D transforms of images of 3D data (stacked along slow dim)
-            #   - 1D transforms of 3D data along fastest dim
-            # Otherwise, we have to configure cuda "advanced memory layout".
-            data_ndims = len(self.shape)
-
-            if data_ndims == 2:
-                n_y, n_x = self.shape
-                along_fast_dim = self.axes[0] == 1
-                self.cufft_shape = n_x if along_fast_dim else n_y
-                self.cufft_batch_size = n_y if along_fast_dim else n_x
-                if not (along_fast_dim):
-                    # Batched vertical 1D FFT on 2D data need advanced data layout
-                    # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout
-                    self._cufft_plan_kwargs = {
-                        "inembed": np.int32([0]),
-                        "istride": n_x,
-                        "idist": 1,
-                        "onembed": np.int32([0]),
-                        "ostride": n_x,
-                        "odist": 1,
-                    }
-
-            if data_ndims == 3:
-                # TODO/FIXME - the following work for C2C but not R2C ?!
-                # fast_axes = [(1, 2), (2, 1), (2,)]
-                fast_axes = [(2,)]
-                if self.axes not in fast_axes:
-                    raise NotImplementedError(
-                        "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions"
-                    )
-                self.cufft_batch_size = self.shape[0]
-                self.cufft_shape = self.shape[1:]
-                if len(self.axes) == 1:
-                    # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny
-                    self.cufft_batch_size = np.prod(self.shape[:2])
-                    self.cufft_shape = (self.shape[-1],)
-                if len(self.cufft_shape) == 1:
-                    self.cufft_shape = self.cufft_shape[0]
-
-    def _configure_normalization(self, normalize):
-        self.normalize = normalize
-        if self.normalize == "ortho":
-            # TODO
-            raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.")
-        self.cufft_scale_inverse = self.normalize == "rescale"
-
-    def _compute_fft_plans(self):
-        self.plan_forward = Plan(  # pylint: disable = E1102
-            self.cufft_shape,
-            self.dtype,
-            self.dtype_out,
-            batch=self.cufft_batch_size,
-            stream=self.processing.stream,
-            **self._cufft_plan_kwargs,
-            # cufft extensible plan API is only supported after 0.5.1
-            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
-            # but there is still no official 0.5.2
-            # ~ auto_allocate=True # cufft extensible plan API
-        )
-        self.plan_inverse = Plan(  # pylint: disable = E1102
-            self.cufft_shape,  # not shape_out
-            self.dtype_out,
-            self.dtype,
-            batch=self.cufft_batch_size,
-            stream=self.processing.stream,
-            **self._cufft_plan_kwargs,
-            # cufft extensible plan API is only supported after 0.5.1
-            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
-            # but there is still no official 0.5.2
-            # ~ auto_allocate=True
-        )
-
-    def fft(self, array, output=None):
-        if output is None:
-            output = self.output_fft = self.processing.allocate_array(
-                "output_fft", self.shape_out, dtype=self.dtype_out
-            )
-        cu_fft(array, output, self.plan_forward, scale=False)  # pylint: disable = E1102
-        return output
-
-    def ifft(self, array, output=None):
-        if output is None:
-            output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
-        cu_ifft(  # pylint: disable = E1102
-            array,
-            output,
-            self.plan_inverse,
-            scale=self.cufft_scale_inverse,
-        )
-        return output
+def get_vkfft_cuda(slf, *args, **kwargs):
+    return _get_vkfft_cuda(*args, **kwargs)
 
 
 class VKCUFFT(_BaseVKFFT):
@@ -153,7 +38,7 @@ class VKCUFFT(_BaseVKFFT):
     implem = "vkfft"
     backend = "cuda"
     ProcessingCls = CudaProcessing
-
+    get_fft_obj = get_vkfft_cuda
 
     def _init_backend(self, backend_options):
         super()._init_backend(backend_options)
@@ -167,13 +52,14 @@ def _has_vkfft(x):
 
         if not __has_vkfft__:
            return False
-
+        _ = VKCUFFT((16,), "f")
        avail = True
    except (ImportError, RuntimeError, OSError, NameError):
        avail = False
    return avail
 
 
+@lru_cache(maxsize=2)
 def has_vkfft(safe=True):
     """
     Determine whether pyvkfft is available.
@@ -184,44 +70,20 @@ def has_vkfft(safe=True):
     """
     if not safe:
         return _has_vkfft(None)
-    ctx = get_context("spawn")
-    with Pool(1, context=ctx) as p:
-        v = p.map(_has_vkfft, [1])[0]
-    return v
-
-
-def _has_skfft(x):
-    # should be run from within a Process
     try:
-
-
-
-
-
-
-
-
-
-def has_skcuda(safe=True):
-    """
-    Determine whether scikit-cuda/CUFFT is available.
-    Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution.
-    Do it in a separate thread.
-    """
-    if not safe:
-        return _has_skfft(None)
-    ctx = get_context("spawn")
-    with Pool(1, context=ctx) as p:
-        v = p.map(_has_skfft, [1])[0]
+        ctx = get_context("spawn")
+        with Pool(1, context=ctx) as p:
+            v = p.map(_has_vkfft, [1])[0]
+    except AssertionError:
+        # Can get AssertionError: daemonic processes are not allowed to have children
+        # if the calling code is already a subprocess
+        return _has_vkfft(None)
     return v
 
 
+@lru_cache(maxsize=2)
 def get_fft_class(backend="vkfft"):
     backends = {
-        "scikit-cuda": SKCUFFT,
-        "skcuda": SKCUFFT,
-        "cufft": SKCUFFT,
-        "scikit": SKCUFFT,
         "vkfft": VKCUFFT,
         "pyvkfft": VKCUFFT,
     }
@@ -237,7 +99,7 @@ def get_fft_class(backend="vkfft"):
 
     avail_fft_implems = get_available_fft_implems()
     if len(avail_fft_implems) == 0:
-        raise RuntimeError("Could not any Cuda FFT implementation. Please install
+        raise RuntimeError("Could not any Cuda FFT implementation. Please install pyvkfft")
     if backend not in avail_fft_implems:
         warnings.warn("Could not get FFT backend '%s'" % backend, RuntimeWarning)
         backend = avail_fft_implems[0]
@@ -245,10 +107,9 @@ def get_fft_class(backend="vkfft"):
     return get_fft_cls(backend)
 
 
+@lru_cache(maxsize=1)
 def get_available_fft_implems():
     avail_implems = []
     if has_vkfft(safe=True):
         avail_implems.append("vkfft")
-    if has_skcuda(safe=True):
-        avail_implems.append("skcuda")
     return avail_implems
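One behavioural note on the rewrite above: pyvkfft plan construction is now optionally memoized. `NABU_FFT_CACHE` is read once at import; a positive value wraps the plan factory in `functools.lru_cache`, otherwise a pass-through decorator is used. A self-contained sketch of the idiom follows; here `no_decorator` is a local stand-in for the helper imported from `nabu.utils`, assumed to simply return its argument:

```python
# Self-contained sketch of the conditional lru_cache idiom used by fft_cuda.py / fft_opencl.py.
import os
from functools import lru_cache

def no_decorator(func):
    return func

n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator

@maybe_cached
def make_plan(shape, dtype):
    print("building plan for", shape, dtype)  # printed on every call unless caching is enabled
    return object()

make_plan((2048, 2048), "f")
make_plan((2048, 2048), "f")  # with NABU_FFT_CACHE > 0, the first plan is reused (no second print)
```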
nabu/processing/fft_opencl.py
CHANGED
@@ -1,15 +1,32 @@
+from functools import lru_cache
+import os
 from multiprocessing import get_context
 from multiprocessing.pool import Pool
+
+from ..utils import BaseClassError, no_decorator
 from .fft_base import _BaseVKFFT
 from ..opencl.processing import OpenCLProcessing
 
 try:
-    from pyvkfft.opencl import VkFFTApp as
+    from pyvkfft.opencl import VkFFTApp as OpenCLVkFFTApp
 
     __has_vkfft__ = True
 except (ImportError, OSError):
     __has_vkfft__ = False
     vk_clfft = None
+    OpenCLVkFFTApp = BaseClassError
+
+n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
+maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
+
+
+@maybe_cached
+def _get_vkfft_opencl(*args, **kwargs):
+    return OpenCLVkFFTApp(*args, **kwargs)
+
+
+def get_vkfft_opencl(slf, *args, **kwargs):
+    return _get_vkfft_opencl(*args, **kwargs)
 
 
 class VKCLFFT(_BaseVKFFT):
@@ -20,7 +37,7 @@ class VKCLFFT(_BaseVKFFT):
     implem = "vkfft"
     backend = "opencl"
     ProcessingCls = OpenCLProcessing
-
+    get_fft_obj = get_vkfft_opencl
 
     def _init_backend(self, backend_options):
         super()._init_backend(backend_options)
@@ -34,7 +51,7 @@ def _has_vkfft(x):
 
         if not __has_vkfft__:
            return False
-
+        _ = VKCLFFT((16,), "f")
        avail = True
    except (RuntimeError, OSError):
        avail = False
@@ -48,7 +65,12 @@ def has_vkfft(safe=True):
     """
     if not safe:
         return _has_vkfft(None)
-
-
-
+    try:
+        ctx = get_context("spawn")
+        with Pool(1, context=ctx) as p:
+            v = p.map(_has_vkfft, [1])[0]
+    except AssertionError:
+        # Can get AssertionError: daemonic processes are not allowed to have children
+        # if the calling code is already a subprocess
+        return _has_vkfft(None)
     return v
nabu/processing/fftshift.py
CHANGED
@@ -25,7 +25,7 @@ class FFTshiftBase:
     axes: tuple, optional
         Axes over which to shift. Default is None, which shifts all axes.
 
-    Other
+    Other Parameters
     ----------------
     backend_options:
         named arguments to pass to CudaProcessing or OpenCLProcessing
CHANGED
|
@@ -146,7 +146,7 @@ class PartialHistogram:
|
|
|
146
146
|
elif self.bin_width == "uint16":
|
|
147
147
|
return self._bin_width_u16(dmin, dmax)
|
|
148
148
|
else:
|
|
149
|
-
raise ValueError
|
|
149
|
+
raise ValueError
|
|
150
150
|
|
|
151
151
|
def _compute_histogram_fixed_bw(self, data, data_range=None):
|
|
152
152
|
dmin, dmax = data.min(), data.max() if data_range is None else data_range
|
nabu/processing/muladd.py
CHANGED
nabu/processing/padding_base.py
CHANGED
nabu/processing/padding_cuda.py
CHANGED
@@ -1,7 +1,6 @@
 import numpy as np
 from ..utils import get_cuda_srcfile, updiv
 from ..cuda.processing import CudaProcessing
-from ..cuda.utils import __has_pycuda__
 from .padding_base import PaddingBase
 
 
@@ -12,7 +11,6 @@ class CudaPadding(PaddingBase):
 
     backend = "cuda"
 
-    # TODO docstring from base class
    def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs):
        super().__init__(shape, pad_width, mode=mode, **kwargs)
        self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {}))