pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/filters/bandpass.py
CHANGED
@@ -1,157 +1,157 @@
|
|
1
1
|
"""
|
2
|
-
Implements class
|
2
|
+
Implements class BandPass and BandPassReconstructed.
|
3
3
|
|
4
4
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
5
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
from typing import Tuple
|
10
9
|
from math import log, sqrt
|
11
|
-
from
|
10
|
+
from typing import Tuple, Union, Optional
|
11
|
+
from pydantic.dataclasses import dataclass
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
|
15
15
|
from ..types import BackendArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from .compose import ComposableFilter
|
18
|
-
from ._utils import
|
19
|
-
crop_real_fourier,
|
20
|
-
shift_fourier,
|
21
|
-
pad_to_length,
|
22
|
-
frequency_grid_at_angle,
|
23
|
-
fftfreqn,
|
24
|
-
)
|
18
|
+
from ._utils import pad_to_length, frequency_grid_at_angle, fftfreqn
|
25
19
|
|
26
20
|
__all__ = ["BandPass", "BandPassReconstructed"]
|
27
21
|
|
28
22
|
|
29
|
-
@dataclass
|
23
|
+
@dataclass(config=dict(extra="allow"))
|
30
24
|
class BandPass(ComposableFilter):
|
31
25
|
"""
|
32
|
-
Generate per-
|
26
|
+
Generate per-tilt Bandpass filter.
|
27
|
+
|
28
|
+
Examples
|
29
|
+
--------
|
30
|
+
|
31
|
+
The following creates an instance of :py:class:`BandPass`
|
32
|
+
|
33
|
+
>>> from tme.filters import BandPass
|
34
|
+
>>> bpf_instance = BandPass(
|
35
|
+
>>> angles=(-70, 0, 30),
|
36
|
+
>>> lowpass=30,
|
37
|
+
>>> sampling_rate=10
|
38
|
+
>>> )
|
39
|
+
|
40
|
+
Differently from :py:class:`tme.filters.BandPassReconstructed`, the filter
|
41
|
+
masks are intended to be used in subsequent reconstruction using
|
42
|
+
:py:class:`tme.filters.ReconstructFromTilt`.
|
43
|
+
|
44
|
+
The ``opening_axis``, ``tilt_axis`` and ``angles`` parameter are used
|
45
|
+
to determine the correct frequencies for non-cubical input shapes. The
|
46
|
+
``shape`` argument contains the shape of the reconstruction.
|
47
|
+
|
48
|
+
>>> ret = bpf_instance(shape=(50,50,25))
|
49
|
+
>>> mask = ret["data"]
|
50
|
+
>>> mask.shape # 3, 50, 50
|
51
|
+
|
52
|
+
Note that different from its reconstructed counterpart, the DC
|
53
|
+
component is at the center of the array.
|
54
|
+
|
55
|
+
>>> import matplotlib.pyplot as plt
|
56
|
+
>>> fix, ax = plt.subplots(nrows=1, ncols=3)
|
57
|
+
>>> _ = [ax[i].imshow(mask[i]) for i in range(mask.shape[0])]
|
58
|
+
>>> plt.show()
|
59
|
+
|
33
60
|
"""
|
34
61
|
|
35
|
-
#: The
|
36
|
-
|
37
|
-
#: The
|
38
|
-
|
39
|
-
#: The
|
40
|
-
|
41
|
-
#:
|
42
|
-
|
62
|
+
#: The lowpass cutoff, defaults to None.
|
63
|
+
lowpass: Optional[float] = None
|
64
|
+
#: The highpass cutoff, defaults to None.
|
65
|
+
highpass: Optional[float] = None
|
66
|
+
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
67
|
+
sampling_rate: Union[Tuple[float, ...], float] = 1
|
68
|
+
#: Whether to use Gaussian bandpass filter, defaults to True.
|
69
|
+
use_gaussian: bool = True
|
70
|
+
#: The tilt angles in degrees.
|
71
|
+
angles: Tuple[float, ...] = None
|
43
72
|
#: Axis the plane is tilted over, defaults to 0 (x).
|
44
73
|
tilt_axis: int = 0
|
45
74
|
#: The projection axis, defaults to 2 (z).
|
46
75
|
opening_axis: int = 2
|
47
|
-
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
48
|
-
sampling_rate: Tuple[float] = 1
|
49
|
-
#: Whether to use Gaussian bandpass filter, defaults to True.
|
50
|
-
use_gaussian: bool = True
|
51
|
-
#: Whether to return a mask for rfft
|
52
|
-
return_real_fourier: bool = False
|
53
76
|
|
54
|
-
def
|
77
|
+
def _evaluate(self, shape: Tuple[int, ...], **kwargs):
|
55
78
|
"""
|
56
79
|
Returns a Bandpass stack of chosen parameters with DC component in the center.
|
57
80
|
"""
|
58
|
-
func_args = vars(self).copy()
|
59
|
-
func_args.update(kwargs)
|
60
|
-
|
61
81
|
func = discrete_bandpass
|
62
|
-
if
|
82
|
+
if kwargs.get("use_gaussian"):
|
63
83
|
func = gaussian_bandpass
|
64
84
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
return_real_fourier = False
|
69
|
-
|
70
|
-
angles = np.atleast_1d(func_args["angles"])
|
71
|
-
_lowpass = pad_to_length(func_args["lowpass"], angles.size)
|
72
|
-
_highpass = pad_to_length(func_args["highpass"], angles.size)
|
85
|
+
angles = np.atleast_1d(kwargs["angles"])
|
86
|
+
_lowpass = pad_to_length(kwargs["lowpass"], angles.size)
|
87
|
+
_highpass = pad_to_length(kwargs["highpass"], angles.size)
|
73
88
|
|
74
89
|
masks = []
|
75
90
|
for index, angle in enumerate(angles):
|
76
91
|
frequency_grid = frequency_grid_at_angle(
|
77
|
-
shape=
|
78
|
-
tilt_axis=
|
79
|
-
opening_axis=
|
92
|
+
shape=shape,
|
93
|
+
tilt_axis=kwargs["tilt_axis"],
|
94
|
+
opening_axis=kwargs["opening_axis"],
|
80
95
|
angle=angle,
|
81
96
|
sampling_rate=1,
|
82
97
|
)
|
83
|
-
|
84
|
-
|
85
|
-
mask = func(grid=frequency_grid, **
|
86
|
-
|
87
|
-
|
88
|
-
if return_real_fourier:
|
89
|
-
mask = crop_real_fourier(mask)
|
90
|
-
masks.append(mask[None])
|
91
|
-
|
92
|
-
masks = be.concatenate(masks, axis=0)
|
93
|
-
return {
|
94
|
-
"data": be.to_backend_array(masks),
|
95
|
-
"shape": func_args["shape"],
|
96
|
-
"return_real_fourier": return_real_fourier,
|
97
|
-
"is_multiplicative_filter": True,
|
98
|
-
}
|
98
|
+
kwargs["lowpass"] = _lowpass[index]
|
99
|
+
kwargs["highpass"] = _highpass[index]
|
100
|
+
mask = func(grid=frequency_grid, **kwargs)
|
101
|
+
masks.append(be.to_backend_array(mask[None]))
|
102
|
+
return {"data": be.concatenate(masks, axis=0), "shape": shape}
|
99
103
|
|
100
104
|
|
101
|
-
@dataclass
|
105
|
+
@dataclass(config=dict(extra="allow"))
|
102
106
|
class BandPassReconstructed(ComposableFilter):
|
103
107
|
"""
|
104
|
-
Generate
|
108
|
+
Generate Bandpass filter for reconstructions.
|
109
|
+
|
110
|
+
Examples
|
111
|
+
--------
|
112
|
+
The following creates an instance of :py:class:`BandPassReconstructed`
|
113
|
+
|
114
|
+
>>> from tme.filters import BandPassReconstructed
|
115
|
+
>>> bpf_instance = BandPassReconstructed(
|
116
|
+
>>> lowpass=30,
|
117
|
+
>>> sampling_rate=10
|
118
|
+
>>> )
|
119
|
+
|
120
|
+
We can use its call method to create filters of given shape
|
121
|
+
|
122
|
+
>>> import matplotlib.pyplot as plt
|
123
|
+
>>> ret = bpf_instance(shape=(50,50))
|
124
|
+
|
125
|
+
The ``data`` key of the returned dictionary contains the corresponding
|
126
|
+
Fourier filter mask. The DC component is located at the origin.
|
127
|
+
|
128
|
+
>>> plt.imshow(ret["data"])
|
129
|
+
>>> plt.show()
|
105
130
|
"""
|
106
131
|
|
107
132
|
#: The lowpass cutoff, defaults to None.
|
108
|
-
lowpass: float = None
|
133
|
+
lowpass: Optional[float] = None
|
109
134
|
#: The highpass cutoff, defaults to None.
|
110
|
-
highpass: float = None
|
111
|
-
#: The shape of the to-be created mask.
|
112
|
-
shape: Tuple[int] = None
|
113
|
-
#: Axis the plane is tilted over, defaults to 0 (x).
|
114
|
-
tilt_axis: int = 0
|
115
|
-
#: The projection axis, defaults to 2 (z).
|
116
|
-
opening_axis: int = 2
|
135
|
+
highpass: Optional[float] = None
|
117
136
|
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
118
|
-
sampling_rate: Tuple[float] = 1
|
137
|
+
sampling_rate: Union[Tuple[float, ...], float] = 1
|
119
138
|
#: Whether to use Gaussian bandpass filter, defaults to True.
|
120
139
|
use_gaussian: bool = True
|
121
|
-
#: Whether to return a mask for rfft
|
122
|
-
return_real_fourier: bool = False
|
123
|
-
|
124
|
-
def __call__(self, **kwargs):
|
125
|
-
func_args = vars(self).copy()
|
126
|
-
func_args.update(kwargs)
|
127
140
|
|
141
|
+
def _evaluate(self, shape: Tuple[int, ...], **kwargs):
|
128
142
|
func = discrete_bandpass
|
129
|
-
if
|
143
|
+
if kwargs.get("use_gaussian"):
|
130
144
|
func = gaussian_bandpass
|
131
145
|
|
132
|
-
return_real_fourier = func_args.get("return_real_fourier", True)
|
133
|
-
shape_is_real_fourier = func_args.get("shape_is_real_fourier", False)
|
134
|
-
if shape_is_real_fourier:
|
135
|
-
return_real_fourier = False
|
136
|
-
|
137
146
|
grid = fftfreqn(
|
138
|
-
shape=
|
147
|
+
shape=shape,
|
139
148
|
sampling_rate=0.5,
|
140
|
-
shape_is_real_fourier=
|
149
|
+
shape_is_real_fourier=False,
|
141
150
|
compute_euclidean_norm=True,
|
151
|
+
fftshift=False,
|
142
152
|
)
|
143
|
-
|
144
|
-
|
145
|
-
mask = shift_fourier(data=mask, shape_is_real_fourier=shape_is_real_fourier)
|
146
|
-
if return_real_fourier:
|
147
|
-
mask = crop_real_fourier(mask)
|
148
|
-
|
149
|
-
return {
|
150
|
-
"data": be.to_backend_array(mask),
|
151
|
-
"shape": func_args["shape"],
|
152
|
-
"return_real_fourier": return_real_fourier,
|
153
|
-
"is_multiplicative_filter": True,
|
154
|
-
}
|
153
|
+
ret = be.to_backend_array(func(grid=grid, **kwargs))
|
154
|
+
return {"data": ret, "shape": shape}
|
155
155
|
|
156
156
|
|
157
157
|
def discrete_bandpass(
|
@@ -169,11 +169,9 @@ def discrete_bandpass(
|
|
169
169
|
grid : BackendArray
|
170
170
|
Frequencies in Fourier space.
|
171
171
|
lowpass : float
|
172
|
-
The lowpass cutoff in units of sampling rate.
|
172
|
+
The lowpass cutoff in spatial units of sampling rate.
|
173
173
|
highpass : float
|
174
|
-
The highpass cutoff in units of sampling rate.
|
175
|
-
return_real_fourier : bool, optional
|
176
|
-
Whether to return only the real Fourier space, defaults to False.
|
174
|
+
The highpass cutoff in spatial units of sampling rate.
|
177
175
|
sampling_rate : float
|
178
176
|
The sampling rate in Fourier space.
|
179
177
|
**kwargs : dict
|
@@ -244,16 +242,13 @@ def gaussian_bandpass(
|
|
244
242
|
if has_lowpass:
|
245
243
|
lowpass = upper_sampling / (lowpass * norm)
|
246
244
|
lowpass = be.multiply(2, be.square(lowpass))
|
247
|
-
|
248
|
-
lowpass_filter = be.divide(grid, lowpass, out=grid)
|
249
|
-
else:
|
250
|
-
lowpass_filter = be.divide(grid, lowpass)
|
245
|
+
lowpass_filter = be.divide(grid, lowpass)
|
251
246
|
lowpass_filter = be.exp(lowpass_filter, out=lowpass_filter)
|
252
247
|
|
253
248
|
if has_highpass:
|
254
249
|
highpass = upper_sampling / (highpass * norm)
|
255
250
|
highpass = be.multiply(2, be.square(highpass))
|
256
|
-
highpass_filter = be.divide(grid, highpass
|
251
|
+
highpass_filter = be.divide(grid, highpass)
|
257
252
|
highpass_filter = be.exp(highpass_filter, out=highpass_filter)
|
258
253
|
highpass_filter = be.subtract(1, highpass_filter, out=highpass_filter)
|
259
254
|
|
@@ -267,5 +262,4 @@ def gaussian_bandpass(
|
|
267
262
|
)
|
268
263
|
else:
|
269
264
|
bandpass_filter = be.full(grid.shape, fill_value=1, dtype=be._float_dtype)
|
270
|
-
|
271
265
|
return bandpass_filter
|
tme/filters/compose.py
CHANGED
@@ -9,74 +9,168 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
9
9
|
from typing import Tuple, Dict
|
10
10
|
from abc import ABC, abstractmethod
|
11
11
|
|
12
|
-
from
|
12
|
+
from ._utils import crop_real_fourier
|
13
|
+
from ..backends import backend as be
|
13
14
|
|
14
15
|
__all__ = ["Compose", "ComposableFilter"]
|
15
16
|
|
16
17
|
|
17
|
-
class
|
18
|
+
class ComposableFilter(ABC):
|
18
19
|
"""
|
19
|
-
|
20
|
+
Base class for composable filters.
|
20
21
|
|
21
|
-
This class
|
22
|
-
|
22
|
+
This class provides a standard interface for filters used in template matching
|
23
|
+
and reconstruction. It automatically handles:
|
23
24
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
25
|
+
- Parameter merging between instance attributes and runtime arguments
|
26
|
+
- Fourier space shifting when needed
|
27
|
+
- Real Fourier transform cropping for efficiency
|
28
|
+
- Standardized result dictionary formatting
|
28
29
|
|
29
|
-
|
30
|
-
|
31
|
-
Dict
|
32
|
-
Metadata resulting from the composed transformations.
|
30
|
+
Subclasses need to implement :py:meth:`ComposableFilter._evaluate` which
|
31
|
+
contains the core filter computation logic.
|
33
32
|
|
33
|
+
By default, all filters are assumed to be multiplicative in Fourier space,
|
34
|
+
which covers the vast majority of use cases (bandpass, CTF, wedge, whitening, etc.).
|
35
|
+
Only explicitly specify non-multiplicative behavior when needed.
|
34
36
|
"""
|
35
37
|
|
36
|
-
|
37
|
-
|
38
|
+
@abstractmethod
|
39
|
+
def _evaluate(self, **kwargs) -> Dict:
|
40
|
+
"""
|
41
|
+
Compute the actual filter given a set of keyword parameters.
|
38
42
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
+
Parameters
|
44
|
+
----------
|
45
|
+
**kwargs : dict
|
46
|
+
Merged parameters from instance attributes and runtime arguments
|
47
|
+
passed to :py:meth:`__call__`. This includes both the filter's
|
48
|
+
configuration parameters and any runtime overrides.
|
43
49
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
50
|
+
Returns
|
51
|
+
-------
|
52
|
+
Dict
|
53
|
+
Dictionary containing the filter result and metadata. Required keys:
|
48
54
|
|
49
|
-
|
50
|
-
|
55
|
+
- data : BackendArray or array-like
|
56
|
+
The computed filter data
|
57
|
+
- shape : tuple of int
|
58
|
+
Input shape the filter was built for.
|
51
59
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
60
|
+
Optional keys:
|
61
|
+
- is_multiplicative_filter : bool
|
62
|
+
Whether the filter is multiplicative in Fourier space (default True)
|
63
|
+
"""
|
56
64
|
|
57
|
-
|
65
|
+
def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
|
66
|
+
"""
|
67
|
+
This method provides the standard interface for creating of composable
|
68
|
+
filter masks. It merges instance attributes with runtime parameters,
|
69
|
+
and ensures Fourier conventions are consistent across filters.
|
58
70
|
|
59
|
-
|
71
|
+
Parameters
|
72
|
+
----------
|
73
|
+
return_real_fourier : bool, optional
|
74
|
+
Whether to crop the filter mask for compatibility with real input
|
75
|
+
FFTs (i.e., :py:func:`numpy.fft.rfft`). When True, only the
|
76
|
+
positive frequency components are returned, reducing memory usage
|
77
|
+
and computation time for real-valued inputs. Default is False.
|
78
|
+
**kwargs : dict
|
79
|
+
Additional keyword arguments passed to :py:meth:`_evaluate`.
|
80
|
+
These will override any matching instance attributes during
|
81
|
+
parameter merging.
|
82
|
+
|
83
|
+
Returns
|
84
|
+
-------
|
85
|
+
Dict
|
86
|
+
- data : BackendArray
|
87
|
+
The processed filter data, converted to the appropriate backend
|
88
|
+
array type and with fourier operations applied as needed
|
89
|
+
- shape : tuple of int or None
|
90
|
+
Shape for which the filter was created
|
91
|
+
- return_real_fourier : bool
|
92
|
+
The value of the return_real_fourier parameter
|
93
|
+
- is_multiplicative_filter : bool
|
94
|
+
Whether the filter is multiplicative in Fourier space
|
95
|
+
- Additional metadata from the filter implementation
|
96
|
+
"""
|
97
|
+
ret = self._evaluate(**(vars(self) | kwargs))
|
60
98
|
|
99
|
+
# This parameter is only here to allow for using Composable filters outside
|
100
|
+
# the context of a Compose operation. Internally, we require return_real_fourier
|
101
|
+
# to be False, e.g., for filters that require reconstruction.
|
102
|
+
if return_real_fourier:
|
103
|
+
ret["data"] = crop_real_fourier(ret["data"])
|
61
104
|
|
62
|
-
|
105
|
+
ret["data"] = be.to_backend_array(ret["data"])
|
106
|
+
ret["return_real_fourier"] = return_real_fourier
|
107
|
+
return ret
|
108
|
+
|
109
|
+
|
110
|
+
class Compose:
|
63
111
|
"""
|
64
|
-
|
112
|
+
Compose a series of filters.
|
113
|
+
|
114
|
+
Parameters
|
115
|
+
----------
|
116
|
+
transforms : tuple of :py:class:`ComposableFilter`.
|
117
|
+
Tuple of filter instances.
|
65
118
|
"""
|
66
119
|
|
67
|
-
|
68
|
-
|
120
|
+
def __init__(self, transforms: Tuple[ComposableFilter, ...]):
|
121
|
+
for transform in transforms:
|
122
|
+
if not isinstance(transform, ComposableFilter):
|
123
|
+
raise ValueError(f"{transform} is not a child of {ComposableFilter}.")
|
124
|
+
|
125
|
+
self.transforms = transforms
|
126
|
+
|
127
|
+
def __call__(self, return_real_fourier: bool = False, **kwargs) -> Dict:
|
69
128
|
"""
|
129
|
+
Apply the sequence of filters in order, chaining their outputs.
|
70
130
|
|
71
131
|
Parameters
|
72
132
|
----------
|
73
|
-
|
74
|
-
|
133
|
+
return_real_fourier : bool, optional
|
134
|
+
Whether to crop the filter mask for compatibility with real input
|
135
|
+
FFTs (i.e., :py:func:`numpy.fft.rfft`). When True, only the
|
136
|
+
positive frequency components are returned, reducing memory usage
|
137
|
+
and computation time for real-valued inputs. Default is False.
|
75
138
|
**kwargs : dict
|
76
|
-
|
139
|
+
Keyword arguments passed to the first filter and propagated through
|
140
|
+
the pipeline.
|
77
141
|
|
78
142
|
Returns
|
79
143
|
-------
|
80
144
|
Dict
|
81
|
-
|
145
|
+
Result dictionary from the final filter in the composition, containing:
|
146
|
+
|
147
|
+
- data : BackendArray
|
148
|
+
The final composite filter data. For multiplicative filters, this is
|
149
|
+
the element-wise product of all individual filter outputs.
|
150
|
+
- shape : tuple of int
|
151
|
+
Shape of the filter data
|
152
|
+
- return_real_fourier : bool
|
153
|
+
Whether the output is compatible with real FFTs
|
154
|
+
- Additional metadata from the filter pipeline
|
82
155
|
"""
|
156
|
+
meta = {}
|
157
|
+
if not len(self.transforms):
|
158
|
+
return meta
|
159
|
+
|
160
|
+
meta = self.transforms[0](**kwargs)
|
161
|
+
for transform in self.transforms[1:]:
|
162
|
+
kwargs.update(meta)
|
163
|
+
ret = transform(**kwargs)
|
164
|
+
|
165
|
+
if "data" not in ret:
|
166
|
+
continue
|
167
|
+
|
168
|
+
if ret.get("is_multiplicative_filter", True):
|
169
|
+
prev_data = meta.pop("data")
|
170
|
+
ret["data"] = be.multiply(ret["data"], prev_data)
|
171
|
+
ret["merge"], prev_data = None, None
|
172
|
+
meta = ret
|
173
|
+
|
174
|
+
if return_real_fourier:
|
175
|
+
meta["data"] = crop_real_fourier(meta["data"])
|
176
|
+
return meta
|