pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/filters/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from .ctf import CTF, CTFReconstructed
|
2
2
|
from .compose import Compose, ComposableFilter
|
3
|
-
from .bandpass import
|
3
|
+
from .bandpass import BandPass, BandPassReconstructed
|
4
4
|
from .whitening import LinearWhiteningFilter
|
5
5
|
from .wedge import Wedge, WedgeReconstructed
|
6
|
-
from .reconstruction import ReconstructFromTilt
|
6
|
+
from .reconstruction import ReconstructFromTilt, ShiftFourier
|
tme/filters/_utils.py
CHANGED
@@ -15,6 +15,18 @@ from ..backends import NumpyFFTWBackend
|
|
15
15
|
from ..types import BackendArray, NDArray
|
16
16
|
from ..rotations import euler_to_rotationmatrix
|
17
17
|
|
18
|
+
__all__ = [
|
19
|
+
"compute_tilt_shape",
|
20
|
+
"centered_grid",
|
21
|
+
"frequency_grid_at_angle",
|
22
|
+
"fftfreqn",
|
23
|
+
"crop_real_fourier",
|
24
|
+
"compute_fourier_shape",
|
25
|
+
"shift_fourier",
|
26
|
+
"create_reconstruction_filter",
|
27
|
+
"pad_to_length",
|
28
|
+
]
|
29
|
+
|
18
30
|
|
19
31
|
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
|
20
32
|
"""
|
@@ -71,21 +83,27 @@ def frequency_grid_at_angle(
|
|
71
83
|
"""
|
72
84
|
Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
|
73
85
|
|
86
|
+
Conceptually, this function generates an accurate frequency grid of tilted
|
87
|
+
projections. Given a non-cubical shape, it is no longer accurate to compute
|
88
|
+
frequencies as Euclidean distances from a centered index grid. This function
|
89
|
+
solves this issue, and makes it possible to create complex filters on
|
90
|
+
non-cubical input shapes.
|
91
|
+
|
74
92
|
Parameters
|
75
93
|
----------
|
76
|
-
shape :
|
94
|
+
shape : tuple of int
|
77
95
|
The shape of the grid.
|
78
96
|
angle : float
|
79
97
|
The angle at which to generate the grid.
|
80
|
-
sampling_rate :
|
98
|
+
sampling_rate : tuple of float
|
81
99
|
The sampling rate for each dimension.
|
82
100
|
opening_axis : int, optional
|
83
|
-
The axis
|
101
|
+
The projection axis, defaults to None.
|
84
102
|
tilt_axis : int, optional
|
85
103
|
The axis along which the grid is tilted, defaults to None.
|
86
104
|
|
87
|
-
Returns
|
88
|
-
|
105
|
+
Returns
|
106
|
+
-------
|
89
107
|
NDArray
|
90
108
|
The frequency grid.
|
91
109
|
"""
|
@@ -231,7 +249,9 @@ def shift_fourier(
|
|
231
249
|
def create_reconstruction_filter(
|
232
250
|
filter_shape: Tuple[int], filter_type: str, **kwargs: Dict
|
233
251
|
):
|
234
|
-
"""
|
252
|
+
"""
|
253
|
+
Create a reconstruction filter of given filter_type. The DC component of
|
254
|
+
the filter will be located in the array center.
|
235
255
|
|
236
256
|
Parameters
|
237
257
|
----------
|
@@ -299,7 +319,7 @@ def create_reconstruction_filter(
|
|
299
319
|
ret = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
|
300
320
|
min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
|
301
321
|
ret *= min_increment * size
|
302
|
-
np.fmin(ret, 1, out=ret)
|
322
|
+
ret = np.fmin(ret, 1, out=ret)
|
303
323
|
elif filter_type == "shepp-logan":
|
304
324
|
ret = freq * np.sinc(freq / 2)
|
305
325
|
elif filter_type == "cosine":
|
@@ -310,3 +330,8 @@ def create_reconstruction_filter(
|
|
310
330
|
raise ValueError("Unsupported filter type")
|
311
331
|
|
312
332
|
return ret
|
333
|
+
|
334
|
+
|
335
|
+
def pad_to_length(arr, length: int):
|
336
|
+
ret = np.atleast_1d(arr)
|
337
|
+
return np.repeat(ret, length // ret.size)
|
tme/filters/bandpass.py
CHANGED
@@ -8,225 +8,264 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
8
8
|
|
9
9
|
from typing import Tuple
|
10
10
|
from math import log, sqrt
|
11
|
+
from dataclasses import dataclass
|
12
|
+
|
13
|
+
import numpy as np
|
11
14
|
|
12
15
|
from ..types import BackendArray
|
13
16
|
from ..backends import backend as be
|
14
17
|
from .compose import ComposableFilter
|
15
|
-
from ._utils import
|
18
|
+
from ._utils import (
|
19
|
+
crop_real_fourier,
|
20
|
+
shift_fourier,
|
21
|
+
pad_to_length,
|
22
|
+
frequency_grid_at_angle,
|
23
|
+
fftfreqn,
|
24
|
+
)
|
16
25
|
|
17
|
-
__all__ = ["
|
26
|
+
__all__ = ["BandPass", "BandPassReconstructed"]
|
18
27
|
|
19
28
|
|
20
|
-
|
29
|
+
@dataclass
|
30
|
+
class BandPass(ComposableFilter):
|
21
31
|
"""
|
22
|
-
Generate
|
23
|
-
|
24
|
-
Parameters
|
25
|
-
----------
|
26
|
-
lowpass : float, optional
|
27
|
-
The lowpass cutoff, defaults to None.
|
28
|
-
highpass : float, optional
|
29
|
-
The highpass cutoff, defaults to None.
|
30
|
-
sampling_rate : Tuple[float], optional
|
31
|
-
The sampling r_position_to_molmapate in Fourier space, defaults to 1.
|
32
|
-
use_gaussian : bool, optional
|
33
|
-
Whether to use Gaussian bandpass filter, defaults to True.
|
34
|
-
return_real_fourier : bool, optional
|
35
|
-
Whether to return only the real Fourier space, defaults to False.
|
36
|
-
shape_is_real_fourier : bool, optional
|
37
|
-
Whether the shape represents the real Fourier space, defaults to False.
|
32
|
+
Generate per-slice Fourier Bandpass filter
|
38
33
|
"""
|
39
34
|
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
highpass: float,
|
61
|
-
sampling_rate: Tuple[float],
|
62
|
-
return_real_fourier: bool = False,
|
63
|
-
shape_is_real_fourier: bool = False,
|
64
|
-
**kwargs,
|
65
|
-
) -> BackendArray:
|
35
|
+
#: The tilt angles.
|
36
|
+
angles: Tuple[float]
|
37
|
+
#: The lowpass cutoffs. Either one or one per angle, defaults to None.
|
38
|
+
lowpass: Tuple[float] = None
|
39
|
+
#: The highpass cutoffs. Either one or one per angle, defaults to None.
|
40
|
+
highpass: Tuple[float] = None
|
41
|
+
#: The shape of the to-be created mask.
|
42
|
+
shape: Tuple[int] = None
|
43
|
+
#: Axis the plane is tilted over, defaults to 0 (x).
|
44
|
+
tilt_axis: int = 0
|
45
|
+
#: The projection axis, defaults to 2 (z).
|
46
|
+
opening_axis: int = 2
|
47
|
+
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
48
|
+
sampling_rate: Tuple[float] = 1
|
49
|
+
#: Whether to use Gaussian bandpass filter, defaults to True.
|
50
|
+
use_gaussian: bool = True
|
51
|
+
#: Whether to return a mask for rfft
|
52
|
+
return_real_fourier: bool = False
|
53
|
+
|
54
|
+
def __call__(self, **kwargs):
|
66
55
|
"""
|
67
|
-
|
68
|
-
|
69
|
-
Parameters
|
70
|
-
----------
|
71
|
-
shape : tuple of int
|
72
|
-
The shape of the bandpass filter.
|
73
|
-
lowpass : float
|
74
|
-
The lowpass cutoff in units of sampling rate.
|
75
|
-
highpass : float
|
76
|
-
The highpass cutoff in units of sampling rate.
|
77
|
-
return_real_fourier : bool, optional
|
78
|
-
Whether to return only the real Fourier space, defaults to False.
|
79
|
-
sampling_rate : float
|
80
|
-
The sampling rate in Fourier space.
|
81
|
-
shape_is_real_fourier : bool, optional
|
82
|
-
Whether the shape represents the real Fourier space, defaults to False.
|
83
|
-
**kwargs : dict
|
84
|
-
Additional keyword arguments.
|
85
|
-
|
86
|
-
Returns
|
87
|
-
-------
|
88
|
-
BackendArray
|
89
|
-
The bandpass filter in Fourier space.
|
56
|
+
Returns a Bandpass stack of chosen parameters with DC component in the center.
|
90
57
|
"""
|
58
|
+
func_args = vars(self).copy()
|
59
|
+
func_args.update(kwargs)
|
60
|
+
|
61
|
+
func = discrete_bandpass
|
62
|
+
if func_args.get("use_gaussian"):
|
63
|
+
func = gaussian_bandpass
|
64
|
+
|
65
|
+
return_real_fourier = kwargs.get("return_real_fourier", True)
|
66
|
+
shape_is_real_fourier = kwargs.get("shape_is_real_fourier", False)
|
91
67
|
if shape_is_real_fourier:
|
92
68
|
return_real_fourier = False
|
93
69
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
)
|
100
|
-
|
101
|
-
|
70
|
+
angles = np.atleast_1d(func_args["angles"])
|
71
|
+
_lowpass = pad_to_length(func_args["lowpass"], angles.size)
|
72
|
+
_highpass = pad_to_length(func_args["highpass"], angles.size)
|
73
|
+
|
74
|
+
masks = []
|
75
|
+
for index, angle in enumerate(angles):
|
76
|
+
frequency_grid = frequency_grid_at_angle(
|
77
|
+
shape=func_args["shape"],
|
78
|
+
tilt_axis=func_args["tilt_axis"],
|
79
|
+
opening_axis=func_args["opening_axis"],
|
80
|
+
angle=angle,
|
81
|
+
sampling_rate=1,
|
82
|
+
)
|
83
|
+
func_args["lowpass"] = _lowpass[index]
|
84
|
+
func_args["highpass"] = _highpass[index]
|
85
|
+
mask = func(grid=frequency_grid, **func_args)
|
102
86
|
|
103
|
-
|
104
|
-
|
105
|
-
|
87
|
+
mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
|
88
|
+
if return_real_fourier:
|
89
|
+
mask = crop_real_fourier(mask)
|
90
|
+
masks.append(mask[None])
|
106
91
|
|
107
|
-
|
108
|
-
|
109
|
-
|
92
|
+
masks = be.concatenate(masks, axis=0)
|
93
|
+
return {
|
94
|
+
"data": be.to_backend_array(masks),
|
95
|
+
"shape": func_args["shape"],
|
96
|
+
"return_real_fourier": return_real_fourier,
|
97
|
+
"is_multiplicative_filter": True,
|
98
|
+
}
|
110
99
|
|
111
|
-
bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
|
112
|
-
bandpass_filter = shift_fourier(
|
113
|
-
data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
|
114
|
-
)
|
115
100
|
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
Additional keyword arguments.
|
150
|
-
|
151
|
-
Returns
|
152
|
-
-------
|
153
|
-
BackendArray
|
154
|
-
The bandpass filter in Fourier space.
|
155
|
-
"""
|
101
|
+
@dataclass
|
102
|
+
class BandPassReconstructed(ComposableFilter):
|
103
|
+
"""
|
104
|
+
Generate reconstructed bandpass filters in Fourier space.
|
105
|
+
"""
|
106
|
+
|
107
|
+
#: The lowpass cutoff, defaults to None.
|
108
|
+
lowpass: float = None
|
109
|
+
#: The highpass cutoff, defaults to None.
|
110
|
+
highpass: float = None
|
111
|
+
#: The shape of the to-be created mask.
|
112
|
+
shape: Tuple[int] = None
|
113
|
+
#: Axis the plane is tilted over, defaults to 0 (x).
|
114
|
+
tilt_axis: int = 0
|
115
|
+
#: The projection axis, defaults to 2 (z).
|
116
|
+
opening_axis: int = 2
|
117
|
+
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
118
|
+
sampling_rate: Tuple[float] = 1
|
119
|
+
#: Whether to use Gaussian bandpass filter, defaults to True.
|
120
|
+
use_gaussian: bool = True
|
121
|
+
#: Whether to return a mask for rfft
|
122
|
+
return_real_fourier: bool = False
|
123
|
+
|
124
|
+
def __call__(self, **kwargs):
|
125
|
+
func_args = vars(self).copy()
|
126
|
+
func_args.update(kwargs)
|
127
|
+
|
128
|
+
func = discrete_bandpass
|
129
|
+
if func_args.get("use_gaussian"):
|
130
|
+
func = gaussian_bandpass
|
131
|
+
|
132
|
+
return_real_fourier = func_args.get("return_real_fourier", True)
|
133
|
+
shape_is_real_fourier = func_args.get("shape_is_real_fourier", False)
|
156
134
|
if shape_is_real_fourier:
|
157
135
|
return_real_fourier = False
|
158
136
|
|
159
137
|
grid = fftfreqn(
|
160
|
-
shape=shape,
|
138
|
+
shape=func_args["shape"],
|
161
139
|
sampling_rate=0.5,
|
162
140
|
shape_is_real_fourier=shape_is_real_fourier,
|
163
141
|
compute_euclidean_norm=True,
|
164
142
|
)
|
165
|
-
|
166
|
-
grid = -be.square(grid, out=grid)
|
143
|
+
mask = func(grid=grid, **func_args)
|
167
144
|
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
|
172
|
-
)
|
145
|
+
mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
|
146
|
+
if return_real_fourier:
|
147
|
+
mask = crop_real_fourier(mask)
|
173
148
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
if has_lowpass:
|
182
|
-
lowpass = upper_sampling / (lowpass * norm)
|
183
|
-
lowpass = be.multiply(2, be.square(lowpass))
|
184
|
-
if not has_highpass:
|
185
|
-
lowpass_filter = be.divide(grid, lowpass, out=grid)
|
186
|
-
else:
|
187
|
-
lowpass_filter = be.divide(grid, lowpass)
|
188
|
-
lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)
|
189
|
-
|
190
|
-
if has_highpass:
|
191
|
-
highpass = upper_sampling / (highpass * norm)
|
192
|
-
highpass = be.multiply(2, be.square(highpass))
|
193
|
-
highpass_filter = be.divide(grid, highpass, out=grid)
|
194
|
-
highpass_filter = be.exp(highpass_filter, out=highpass_filter)
|
195
|
-
highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)
|
196
|
-
|
197
|
-
if has_lowpass and not has_highpass:
|
198
|
-
bandpass_filter = lowpass_filter
|
199
|
-
elif not has_lowpass and has_highpass:
|
200
|
-
bandpass_filter = highpass_filter
|
201
|
-
elif has_lowpass and has_highpass:
|
202
|
-
bandpass_filter = be.multiply(
|
203
|
-
lowpass_filter, highpass_filter, out=lowpass_filter
|
204
|
-
)
|
205
|
-
else:
|
206
|
-
bandpass_filter = be.full(shape, fill_value=1, dtype=be._float_dtype)
|
149
|
+
return {
|
150
|
+
"data": be.to_backend_array(mask),
|
151
|
+
"shape": func_args["shape"],
|
152
|
+
"return_real_fourier": return_real_fourier,
|
153
|
+
"is_multiplicative_filter": True,
|
154
|
+
}
|
207
155
|
|
208
|
-
bandpass_filter = shift_fourier(
|
209
|
-
data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
|
210
|
-
)
|
211
156
|
|
212
|
-
|
213
|
-
|
157
|
+
def discrete_bandpass(
|
158
|
+
grid: BackendArray,
|
159
|
+
lowpass: float,
|
160
|
+
highpass: float,
|
161
|
+
sampling_rate: Tuple[float],
|
162
|
+
**kwargs,
|
163
|
+
) -> BackendArray:
|
164
|
+
"""
|
165
|
+
Generate a bandpass filter using discrete frequency cutoffs.
|
214
166
|
|
215
|
-
|
167
|
+
Parameters
|
168
|
+
----------
|
169
|
+
grid : BackendArray
|
170
|
+
Frequencies in Fourier space.
|
171
|
+
lowpass : float
|
172
|
+
The lowpass cutoff in units of sampling rate.
|
173
|
+
highpass : float
|
174
|
+
The highpass cutoff in units of sampling rate.
|
175
|
+
return_real_fourier : bool, optional
|
176
|
+
Whether to return only the real Fourier space, defaults to False.
|
177
|
+
sampling_rate : float
|
178
|
+
The sampling rate in Fourier space.
|
179
|
+
**kwargs : dict
|
180
|
+
Additional keyword arguments.
|
181
|
+
|
182
|
+
Returns
|
183
|
+
-------
|
184
|
+
BackendArray
|
185
|
+
The bandpass filter in Fourier space.
|
186
|
+
"""
|
187
|
+
grid = be.astype(be.to_backend_array(grid), be._float_dtype)
|
188
|
+
sampling_rate = be.to_backend_array(sampling_rate)
|
216
189
|
|
217
|
-
|
218
|
-
|
219
|
-
|
190
|
+
highcut = grid.max()
|
191
|
+
if lowpass is not None:
|
192
|
+
highcut = be.max(2 * sampling_rate / lowpass)
|
220
193
|
|
221
|
-
|
222
|
-
|
223
|
-
|
194
|
+
lowcut = 0
|
195
|
+
if highpass is not None:
|
196
|
+
lowcut = be.max(2 * sampling_rate / highpass)
|
224
197
|
|
225
|
-
|
198
|
+
bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
|
199
|
+
return bandpass_filter
|
226
200
|
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
201
|
+
|
202
|
+
def gaussian_bandpass(
|
203
|
+
grid: BackendArray,
|
204
|
+
lowpass: float = None,
|
205
|
+
highpass: float = None,
|
206
|
+
sampling_rate: float = 1,
|
207
|
+
**kwargs,
|
208
|
+
) -> BackendArray:
|
209
|
+
"""
|
210
|
+
Generate a bandpass filter using Gaussians.
|
211
|
+
|
212
|
+
Parameters
|
213
|
+
----------
|
214
|
+
grid : BackendArray
|
215
|
+
Frequency grid in Fourier space.
|
216
|
+
lowpass : float, optional
|
217
|
+
The lowpass cutoff in units of sampling rate, defaults to None.
|
218
|
+
highpass : float, optional
|
219
|
+
The highpass cutoff in units of sampling rate, defaults to None.
|
220
|
+
sampling_rate : float, optional
|
221
|
+
The sampling rate in Fourier space, defaults to one.
|
222
|
+
**kwargs : dict
|
223
|
+
Additional keyword arguments.
|
224
|
+
|
225
|
+
Returns
|
226
|
+
-------
|
227
|
+
BackendArray
|
228
|
+
The bandpass filter in Fourier space.
|
229
|
+
"""
|
230
|
+
grid = be.astype(be.to_backend_array(grid), be._float_dtype)
|
231
|
+
grid = -be.square(grid, out=grid)
|
232
|
+
|
233
|
+
has_lowpass, has_highpass = False, False
|
234
|
+
norm = float(sqrt(2 * log(2)))
|
235
|
+
upper_sampling = float(be.max(be.multiply(2, be.to_backend_array(sampling_rate))))
|
236
|
+
|
237
|
+
if lowpass is not None:
|
238
|
+
lowpass, has_lowpass = float(lowpass), True
|
239
|
+
lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
|
240
|
+
if highpass is not None:
|
241
|
+
highpass, has_highpass = float(highpass), True
|
242
|
+
highpass = be.maximum(highpass, be.eps(be._float_dtype))
|
243
|
+
|
244
|
+
if has_lowpass:
|
245
|
+
lowpass = upper_sampling / (lowpass * norm)
|
246
|
+
lowpass = be.multiply(2, be.square(lowpass))
|
247
|
+
if not has_highpass:
|
248
|
+
lowpass_filter = be.divide(grid, lowpass, out=grid)
|
249
|
+
else:
|
250
|
+
lowpass_filter = be.divide(grid, lowpass)
|
251
|
+
lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)
|
252
|
+
|
253
|
+
if has_highpass:
|
254
|
+
highpass = upper_sampling / (highpass * norm)
|
255
|
+
highpass = be.multiply(2, be.square(highpass))
|
256
|
+
highpass_filter = be.divide(grid, highpass, out=grid)
|
257
|
+
highpass_filter = be.exp(highpass_filter, out=highpass_filter)
|
258
|
+
highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)
|
259
|
+
|
260
|
+
if has_lowpass and not has_highpass:
|
261
|
+
bandpass_filter = lowpass_filter
|
262
|
+
elif not has_lowpass and has_highpass:
|
263
|
+
bandpass_filter = highpass_filter
|
264
|
+
elif has_lowpass and has_highpass:
|
265
|
+
bandpass_filter = be.multiply(
|
266
|
+
lowpass_filter, highpass_filter, out=lowpass_filter
|
267
|
+
)
|
268
|
+
else:
|
269
|
+
bandpass_filter = be.full(grid.shape, fill_value=1, dtype=be._float_dtype)
|
270
|
+
|
271
|
+
return bandpass_filter
|