pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.0.data/scripts/match_template.py +1019 -0
- pytme-0.2.0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
- pytme-0.2.0.dist-info/RECORD +72 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +459 -218
- pytme-0.1.8.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +533 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +35 -6
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +173 -78
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +78 -32
- tme/matching_exhaustive.py +369 -221
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +753 -649
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +176 -0
- tme/preprocessing/composable_filter.py +30 -0
- tme/preprocessing/compose.py +52 -0
- tme/preprocessing/frequency_filters.py +322 -0
- tme/preprocessing/tilt_series.py +967 -0
- tme/preprocessor.py +35 -25
- tme/structure.py +2 -37
- pytme-0.1.8.data/scripts/postprocess.py +0 -625
- pytme-0.1.8.dist-info/RECORD +0 -61
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
|
|
1
|
+
""" Utilities for the generation of frequency grids.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
from numpy.typing import NDArray
|
12
|
+
|
13
|
+
from ..backends import backend
|
14
|
+
from ..matching_utils import euler_to_rotationmatrix
|
15
|
+
|
16
|
+
|
17
|
+
def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool = False):
    """
    Compute the shape obtained by collapsing ``opening_axis`` of ``shape``.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the input array.
    opening_axis : int
        Index of the axis that is collapsed.
    reduce_dim : bool, optional (default=False)
        If True, the collapsed axis is dropped entirely; otherwise it is kept
        with size 1.

    Returns:
    --------
    Tuple[int]
        The resulting shape.
    """
    if reduce_dim:
        return tuple(size for axis, size in enumerate(shape) if axis != opening_axis)
    return tuple(1 if axis == opening_axis else size for axis, size in enumerate(shape))
|
40
|
+
|
41
|
+
|
42
|
+
def centered_grid(shape: Tuple[int]) -> NDArray:
    """
    Build an integer coordinate grid whose origin sits at index ``size // 2``.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.

    Returns:
    --------
    NDArray
        Array of shape ``(len(shape), *shape)``; entry ``i`` holds the
        coordinate along axis ``i`` relative to index ``shape[i] // 2``.
    """
    offsets = [np.arange(size) - size // 2 for size in shape]
    return np.array(np.meshgrid(*offsets, indexing="ij"))
|
60
|
+
|
61
|
+
|
62
|
+
def frequency_grid_at_angle(
    shape: Tuple[int],
    angle: float,
    sampling_rate: Tuple[float],
    opening_axis: int = None,
    tilt_axis: int = None,
) -> NDArray:
    """
    Compute radial frequency magnitudes on a centered, optionally rotated grid.

    The grid is centered at ``shape // 2``. ``opening_axis`` is collapsed to
    size 1. Each coordinate along axis ``i`` is scaled by
    ``1 / (2 * sampling_rate[i] * (shape[i] // 2))``, so an offset of
    ``shape[i] // 2`` maps to ``1 / (2 * sampling_rate[i])``.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the grid.
    angle : float
        Rotation angle applied to the grid coordinates. No rotation if 0.
    sampling_rate : Tuple[float]
        Sampling rate per dimension. A scalar is repeated for every dimension.
    opening_axis : int, optional
        Axis collapsed to size 1. It must be given when ``angle != 0``.
    tilt_axis : int, optional
        Axis whose Euler angle is set to ``angle``. It must be given when
        ``angle != 0``.

    Returns:
    --------
    NDArray
        Euclidean norm over the scaled coordinates, with all size-1 axes
        squeezed out (including the collapsed ``opening_axis``).
    """
    rates = np.array(sampling_rate)
    rates = np.repeat(rates, len(shape) // rates.size)

    collapsed_shape = compute_tilt_shape(
        shape=shape, opening_axis=opening_axis, reduce_dim=False
    )
    coordinates = centered_grid(shape=collapsed_shape)

    if angle != 0:
        euler_angles = np.zeros(len(shape))
        euler_angles[tilt_axis] = angle
        # NOTE(review): the Euler convention is defined by
        # euler_to_rotationmatrix; the roll aligns the angle with opening_axis.
        rotation = euler_to_rotationmatrix(np.roll(euler_angles, opening_axis - 1))
        coordinates = np.einsum("ij,j...->i...", rotation, coordinates)

    # NOTE(review): axes of size 1 in `shape` give shape // 2 == 0 here,
    # i.e. a division by zero.
    half_extent = np.divide(shape, 2).astype(int)
    scale = np.divide(1, 2 * rates * half_extent)
    scaled = np.multiply(coordinates.T, scale).T
    return np.linalg.norm(np.squeeze(scaled), axis=0)
|
109
|
+
|
110
|
+
|
111
|
+
def fftfreqn(
    shape: Tuple[int],
    sampling_rate: Tuple[float],
    compute_euclidean_norm: bool = False,
    shape_is_real_fourier: bool = False,
) -> NDArray:
    """
    Generate the n-dimensional discrete Fourier Transform sample frequencies.

    The frequencies are centered: the zero frequency sits at index
    ``shape // 2`` along every axis. The exception is the last axis when
    ``shape_is_real_fourier`` is set, where it sits at index 0.

    Parameters:
    -----------
    shape : Tuple[int]
        The shape of the data.
    sampling_rate : float or Tuple[float]
        The sampling rate. The offsets along each axis are divided by
        ``int(shape * sampling_rate)``. If None, no normalization is applied
        and the values are the index offsets from the center.
    compute_euclidean_norm : bool, optional
        Whether to compute the Euclidean norm over all axes, defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape corresponds to a real Fourier transform, i.e. the
        last axis only holds non-negative frequencies, defaults to False.

    Returns:
    --------
    NDArray
        Per-axis frequencies of shape ``(len(shape), *shape)``. If
        ``compute_euclidean_norm`` is set, their Euclidean norm of shape
        ``shape`` is returned instead.
    """
    center = backend.astype(backend.divide(shape, 2), backend._default_dtype_int)

    # One normalization factor per axis. This was previously hard-coded to
    # three entries, which failed to broadcast for non-3D shapes when no
    # sampling_rate was given.
    norm = np.ones(len(shape))
    if sampling_rate is not None:
        norm = backend.multiply(shape, sampling_rate).astype(int)

    if shape_is_real_fourier:
        # The last axis starts at the zero frequency instead of being centered.
        center[-1] = 0
        norm[-1] = 1
        if sampling_rate is not None:
            # For per-axis sampling rates, only the rate of the last axis
            # applies. Multiplying an int by a tuple would repeat the tuple.
            last_rate = sampling_rate
            if isinstance(last_rate, (tuple, list)) or getattr(last_rate, "ndim", 0) > 0:
                last_rate = last_rate[-1]
            norm[-1] = (shape[-1] - 1) * 2 * last_rate

    # Transpose so the per-axis center/norm broadcast over the trailing axis.
    indices = backend.transpose(backend.indices(shape))
    indices -= center
    indices = backend.divide(indices, norm)
    indices = backend.transpose(indices)

    if compute_euclidean_norm:
        backend.square(indices, indices)
        indices = backend.sum(indices, axis=0)
        indices = backend.sqrt(indices)

    return indices
|
159
|
+
|
160
|
+
|
161
|
+
def crop_real_fourier(data: NDArray) -> NDArray:
    """
    Keep the first ``n // 2 + 1`` entries along the last axis of ``data``.

    For a full Fourier transform in unshifted order, this is the half that a
    real FFT (e.g. ``numpy.fft.rfftn``) stores.

    Parameters:
    -----------
    data : NDArray
        The Fourier transformed data.

    Returns:
    --------
    NDArray
        A view of ``data`` restricted along the last axis.
    """
    n_keep = data.shape[-1] // 2 + 1
    return data[..., :n_keep]
|
@@ -0,0 +1,30 @@
|
|
1
|
+
""" Defines a specification for filters that can be used with
|
2
|
+
:py:class:`tme.preprocessing.compose.Compose`.
|
3
|
+
|
4
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
from typing import Dict
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
|
11
|
+
class ComposableFilter(ABC):
    """
    Abstract interface for filters that can be chained by
    :py:class:`tme.preprocessing.compose.Compose`.

    Subclasses implement ``__call__`` and return a dictionary that describes
    the result of the filter.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict:
        """
        Apply the filter.

        Parameters:
        -----------
        *args : tuple
            Positional arguments supplied by the caller.
        **kwargs : dict
            Keyword arguments supplied by the caller.

        Returns:
        --------
        Dict
            Mapping describing the outcome of the filter operation.
        """
        ...
|
@@ -0,0 +1,52 @@
|
|
1
|
+
""" Combine filters using an interface analogous to pytorch's Compose.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple, Dict
|
9
|
+
|
10
|
+
from tme.backends import backend
|
11
|
+
|
12
|
+
|
13
|
+
class Compose:
    """
    Compose a series of transformations.

    Each transformation is a callable that accepts keyword arguments and
    returns a metadata dictionary. Every transformation after the first
    receives the original keyword arguments, updated with the metadata of
    all preceding transformations. If a transformation marks its result with
    ``is_multiplicative_filter``, its ``data`` is multiplied with the
    ``data`` of the preceding result.

    Parameters:
    -----------
    transforms : Tuple[object]
        A tuple containing transformation objects.
    """

    def __init__(self, transforms: Tuple[object]):
        self.transforms = transforms

    def __call__(self, **kwargs: Dict) -> Dict:
        """
        Apply all transformations in order.

        Parameters:
        -----------
        **kwargs : Dict
            Keyword arguments passed to the first transformation. They are
            merged with the accumulated metadata for every later one.

        Returns:
        --------
        Dict
            Metadata returned by the last transformation, or an empty
            dictionary if there are no transformations.
        """
        if not len(self.transforms):
            return {}

        meta = self.transforms[0](**kwargs)
        for transform in self.transforms[1:]:
            # Later transformations see the metadata of their predecessors.
            kwargs.update(meta)
            ret = transform(**kwargs)

            # Only combine when the previous step actually produced data;
            # otherwise there is nothing to multiply with.
            if ret.get("is_multiplicative_filter", False) and "data" in meta:
                # Allocate the product rather than writing into ret["data"].
                # The in-place form fails when the operands broadcast to a
                # larger shape or cannot be cast to ret["data"]'s dtype, and
                # it overwrites an array the transform may still hold.
                ret["data"] = backend.multiply(ret["data"], meta["data"])
                ret["merge"] = None

            meta = ret

        return meta
|
@@ -0,0 +1,322 @@
|
|
1
|
+
""" Defines Fourier frequency filters.
|
2
|
+
|
3
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
from typing import Tuple, Dict
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from numpy.typing import NDArray
|
11
|
+
from scipy.ndimage import mean as ndimean
|
12
|
+
|
13
|
+
from ._utils import fftfreqn, crop_real_fourier
|
14
|
+
from ..backends import backend
|
15
|
+
|
16
|
+
|
17
|
+
class BandPassFilter:
    """
    Generate bandpass filters in Fourier space.

    Filters use either hard frequency cutoffs (:py:meth:`discrete_bandpass`)
    or Gaussian roll-offs (:py:meth:`gaussian_bandpass`). The returned masks
    are in unshifted FFT order, i.e. the zero frequency is at index 0.

    Parameters:
    -----------
    lowpass : float, optional
        The lowpass cutoff in units of sampling rate, defaults to None
        (no lowpass).
    highpass : float, optional
        The highpass cutoff in units of sampling rate, defaults to None
        (no highpass).
    sampling_rate : Tuple[float], optional
        The sampling rate of the data, in the same units as the cutoffs,
        defaults to 1.
    use_gaussian : bool, optional
        Whether to use the Gaussian bandpass filter, defaults to True.
    return_real_fourier : bool, optional
        Whether to crop the filter to the last-axis length of a real FFT,
        defaults to False.
    shape_is_real_fourier : bool, optional
        Whether the shape passed at call time already is a real Fourier
        shape, defaults to False.
    """

    def __init__(
        self,
        lowpass: float = None,
        highpass: float = None,
        sampling_rate: Tuple[float] = 1,
        use_gaussian: bool = True,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
    ):
        self.lowpass = lowpass
        self.highpass = highpass
        self.use_gaussian = use_gaussian
        self.return_real_fourier = return_real_fourier
        self.shape_is_real_fourier = shape_is_real_fourier
        self.sampling_rate = sampling_rate

    @staticmethod
    def discrete_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: Tuple[float],
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> NDArray:
        """
        Generate a bandpass filter using discrete frequency cutoffs.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None or 0 disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float or tuple of float
            The sampling rate. For per-axis rates the largest one is used.
        return_real_fourier : bool, optional
            Whether to crop the filter to the last-axis length of a real FFT,
            defaults to False. Ignored if ``shape_is_real_fourier`` is set.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns:
        --------
        NDArray
            Binary (0.0 / 1.0) bandpass filter in unshifted Fourier order.
        """
        # A shape that already is a real Fourier shape must not be cropped a
        # second time. This mirrors the guard in gaussian_bandpass.
        if shape_is_real_fourier:
            return_real_fourier = False

        # Radial frequency grid, scaled so that an index offset of
        # shape // 2 from the origin maps to 1.
        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )

        lowpass = 0 if lowpass is None else lowpass
        highpass = 1e10 if highpass is None else highpass

        # Convert the cutoffs to grid units. np.multiply also scales per-axis
        # sampling rates given as tuples, where `2 * tuple` would repeat them.
        upper_sampling = np.multiply(2, sampling_rate)
        highcut = grid.max()
        if lowpass > 0:
            highcut = np.max(np.divide(upper_sampling, lowpass))
        lowcut = np.max(np.divide(upper_sampling, highpass))

        bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0

        # Move the zero frequency from the grid center to index 0
        # (ifftshift). The last axis of a real Fourier shape is not centered.
        shift = backend.add(
            backend.astype(backend.divide(bandpass_filter.shape, 2), int),
            backend.mod(bandpass_filter.shape, 2),
        )
        if shape_is_real_fourier:
            shift[-1] = 0

        bandpass_filter = backend.roll(
            bandpass_filter, shift, tuple(i for i in range(len(shift)))
        )

        if return_real_fourier:
            bandpass_filter = crop_real_fourier(bandpass_filter)

        return bandpass_filter

    @staticmethod
    def gaussian_bandpass(
        shape: Tuple[int],
        lowpass: float,
        highpass: float,
        sampling_rate: float,
        return_real_fourier: bool = False,
        shape_is_real_fourier: bool = False,
        **kwargs,
    ) -> NDArray:
        """
        Generate a bandpass filter using Gaussian functions.

        The lowpass component is ``exp(-f**2 / (2 * sigma**2))`` and the
        highpass component is ``1 - exp(-f**2 / (2 * sigma**2))``. Here ``f``
        is the radial grid value, and ``sigma`` is chosen so that each
        component equals 0.5 at the grid radius
        ``2 * max(sampling_rate) / cutoff``.

        Parameters:
        -----------
        shape : tuple of int
            The shape of the bandpass filter.
        lowpass : float
            The lowpass cutoff in units of sampling rate. None disables it.
        highpass : float
            The highpass cutoff in units of sampling rate. None disables it.
        sampling_rate : float
            The sampling rate. For per-axis rates the largest one is used.
        return_real_fourier : bool, optional
            Whether to crop the filter to the last-axis length of a real FFT,
            defaults to False. Ignored if ``shape_is_real_fourier`` is set.
        shape_is_real_fourier : bool, optional
            Whether the shape represents the real Fourier space, defaults to False.
        **kwargs : dict
            Additional keyword arguments, ignored.

        Returns:
        --------
        NDArray
            The bandpass filter in unshifted Fourier order.
        """
        if shape_is_real_fourier:
            return_real_fourier = False

        grid = fftfreqn(
            shape=shape,
            sampling_rate=0.5,
            shape_is_real_fourier=shape_is_real_fourier,
            compute_euclidean_norm=True,
        )
        grid = -backend.square(grid)

        lowpass_filter, highpass_filter = 1, 1
        # sqrt(2 ln 2) converts a half-maximum radius into a standard deviation.
        norm = float(backend.sqrt(2 * backend.log(2)))
        upper_sampling = float(backend.max(backend.multiply(2, sampling_rate)))

        # Clamp the cutoffs to a small positive value so that a cutoff of 0
        # does not divide by zero below.
        if lowpass is not None:
            lowpass = float(lowpass)
            lowpass = backend.maximum(lowpass, backend.eps(lowpass))
        if highpass is not None:
            highpass = float(highpass)
            highpass = backend.maximum(highpass, backend.eps(highpass))

        # After these lines `lowpass` / `highpass` hold 2 * sigma**2.
        if lowpass is not None:
            lowpass = upper_sampling / (lowpass * norm)
            lowpass = backend.multiply(2, backend.square(lowpass))
            lowpass_filter = backend.exp(backend.divide(grid, lowpass))
        if highpass is not None:
            highpass = upper_sampling / (highpass * norm)
            highpass = backend.multiply(2, backend.square(highpass))
            highpass_filter = 1 - backend.exp(backend.divide(grid, highpass))

        lowpass_filter = backend.multiply(lowpass_filter, highpass_filter)

        # Move the zero frequency from the grid center to index 0 (ifftshift).
        shift = backend.add(
            backend.astype(backend.divide(lowpass_filter.shape, 2), int),
            backend.mod(lowpass_filter.shape, 2),
        )
        if shape_is_real_fourier:
            shift[-1] = 0

        lowpass_filter = backend.roll(
            lowpass_filter, shift, tuple(i for i in range(len(shift)))
        )

        if return_real_fourier:
            lowpass_filter = crop_real_fourier(lowpass_filter)

        return lowpass_filter

    def __call__(self, **kwargs):
        """
        Build the bandpass filter.

        Keyword arguments (e.g. ``shape``) override the attributes set at
        construction for this call only. The instance itself is left unchanged.

        Returns:
        --------
        Dict
            The filter under ``"data"``, the sampling rate used, and
            ``"is_multiplicative_filter": True``.
        """
        # Merge into a new dict. Updating vars(self) directly would write
        # every call-time argument into the instance __dict__ and leak it
        # into subsequent calls.
        func_args = {**vars(self), **kwargs}

        func = self.discrete_bandpass
        if func_args.get("use_gaussian"):
            func = self.gaussian_bandpass

        mask = func(**func_args)

        return {
            "data": backend.to_backend_array(mask),
            "sampling_rate": func_args.get("sampling_rate", 1),
            "is_multiplicative_filter": True,
        }
|
220
|
+
|
221
|
+
|
222
|
+
class LinearWhiteningFilter:
    """
    Compute a whitening filter from the radially averaged spectrum of the data.

    The filter is the reciprocal of the root mean power per radial frequency
    shell, scaled to a maximum of 1. It is returned in ``rfftn`` layout as a
    multiplicative filter.

    Parameters:
    -----------
    **kwargs : Dict, optional
        Additional keyword arguments, ignored.
    """

    def __init__(self, **kwargs):
        pass

    @staticmethod
    def _compute_spectrum(
        data_rfft: NDArray, n_bins: int = None
    ) -> Tuple[NDArray, NDArray]:
        """
        Compute per-shell whitening weights of the input spectrum.

        Parameters:
        -----------
        data_rfft : NDArray
            Real Fourier transform of the input data (``rfftn`` layout).
        n_bins : int, optional
            Number of radial shells. Defaults to None, which uses the maximum
            supported number; larger values are clamped to that maximum.

        Returns:
        --------
        bins : NDArray
            Shell label of every element of ``data_rfft``, in fftshifted order
            along all but the last axis.
        radial_averages : NDArray
            Weight per shell label, i.e. ``radial_averages[label]`` is the
            reciprocal root mean power of that shell, scaled to a maximum of 1.
            Labels that contain no element get a weight of 0.
        """
        # The last rfft axis already only spans frequencies 0 ... n // 2.
        max_bins = max(
            max(data_rfft.shape[:-1], default=0) // 2 + 1, data_rfft.shape[-1]
        )
        n_bins = max_bins if n_bins is None else n_bins
        n_bins = int(min(n_bins, max_bins))

        grid = fftfreqn(
            shape=data_rfft.shape,
            sampling_rate=None,
            shape_is_real_fourier=True,
            compute_euclidean_norm=True,
        )
        _, bin_edges = np.histogram(grid, bins=n_bins - 1)
        bins = np.digitize(grid, bins=bin_edges, right=True)

        fft_shift_axes = tuple(range(data_rfft.ndim - 1))
        fourier_transform = np.fft.fftshift(data_rfft, axes=fft_shift_axes)
        fourier_spectrum = np.square(np.abs(fourier_transform))

        # ndimean returns one value per entry of `labels`, ordered like
        # `labels`. The results are scattered back to their label positions
        # below. Without that, an empty shell would shift every following
        # average, and `radial_averages[bins]` would pick the wrong shell
        # (or index out of bounds).
        labels = np.unique(bins)
        shell_averages = ndimean(fourier_spectrum, labels=bins, index=labels)

        np.sqrt(shell_averages, out=shell_averages)
        np.reciprocal(shell_averages, out=shell_averages)
        np.divide(shell_averages, shell_averages.max(), out=shell_averages)

        radial_averages = np.zeros(bins.max() + 1, dtype=shell_averages.dtype)
        radial_averages[labels] = shell_averages

        return bins, radial_averages

    def __call__(
        self,
        data: NDArray = None,
        data_rfft: NDArray = None,
        n_bins: int = None,
        **kwargs: Dict,
    ) -> Dict:
        """
        Compute the linear whitening filter for the given data.

        Parameters:
        -----------
        data : NDArray, optional
            The input data. Used only if ``data_rfft`` is None.
        data_rfft : NDArray, optional
            The real Fourier transform of the input data, defaults to None.
        n_bins : int, optional
            The number of bins for computing the spectrum, defaults to None.
        **kwargs : Dict
            Additional keyword arguments, ignored.

        Returns:
        --------
        Dict
            The filter in ``rfftn`` layout under ``"data"`` and
            ``"is_multiplicative_filter": True``.

        Raises:
        -------
        ValueError
            If neither ``data`` nor ``data_rfft`` is provided.
        """
        if data_rfft is None:
            if data is None:
                raise ValueError("Either data or data_rfft has to be provided.")
            data_rfft = np.fft.rfftn(backend.to_numpy_array(data))

        data_rfft = backend.to_numpy_array(data_rfft)

        bins, radial_averages = self._compute_spectrum(data_rfft, n_bins)

        # Map the shell weights onto the grid and undo the fftshift applied
        # in _compute_spectrum, so the filter matches the rfftn layout.
        radial_averages = np.fft.ifftshift(
            radial_averages[bins], axes=tuple(range(data_rfft.ndim - 1))
        )

        return {
            "data": backend.to_backend_array(radial_averages),
            "is_multiplicative_filter": True,
        }
|