pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/preprocessing/_utils.py
CHANGED
@@ -5,12 +5,13 @@
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
7
|
|
8
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple, List
|
9
9
|
|
10
10
|
import numpy as np
|
11
|
-
from numpy.typing import NDArray
|
12
11
|
|
13
|
-
from ..backends import backend
|
12
|
+
from ..backends import backend as be
|
13
|
+
from ..backends import NumpyFFTWBackend
|
14
|
+
from ..types import BackendArray, NDArray
|
14
15
|
from ..matching_utils import euler_to_rotationmatrix
|
15
16
|
|
16
17
|
|
@@ -93,18 +94,27 @@ def frequency_grid_at_angle(
|
|
93
94
|
tilt_shape = compute_tilt_shape(
|
94
95
|
shape=shape, opening_axis=opening_axis, reduce_dim=False
|
95
96
|
)
|
96
|
-
|
97
|
+
|
98
|
+
if angle == 0:
|
99
|
+
index_grid = fftfreqn(
|
100
|
+
tuple(x for x in tilt_shape if x != 1),
|
101
|
+
sampling_rate=1,
|
102
|
+
compute_euclidean_norm=True,
|
103
|
+
)
|
104
|
+
|
97
105
|
if angle != 0:
|
98
106
|
angles = np.zeros(len(shape))
|
99
107
|
angles[tilt_axis] = angle
|
100
108
|
rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
|
109
|
+
|
110
|
+
index_grid = fftfreqn(tilt_shape, sampling_rate=None)
|
101
111
|
index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
|
112
|
+
norm = np.multiply(sampling_rate, shape).astype(int)
|
102
113
|
|
103
|
-
|
114
|
+
index_grid = np.divide(index_grid.T, norm).T
|
115
|
+
index_grid = np.squeeze(index_grid)
|
116
|
+
index_grid = np.linalg.norm(index_grid, axis=(0))
|
104
117
|
|
105
|
-
index_grid = np.multiply(index_grid.T, norm).T
|
106
|
-
index_grid = np.squeeze(index_grid)
|
107
|
-
index_grid = np.linalg.norm(index_grid, axis=(0))
|
108
118
|
return index_grid
|
109
119
|
|
110
120
|
|
@@ -113,9 +123,10 @@ def fftfreqn(
|
|
113
123
|
sampling_rate: Tuple[float],
|
114
124
|
compute_euclidean_norm: bool = False,
|
115
125
|
shape_is_real_fourier: bool = False,
|
126
|
+
return_sparse_grid: bool = False,
|
116
127
|
) -> NDArray:
|
117
128
|
"""
|
118
|
-
Generate the n-dimensional discrete Fourier
|
129
|
+
Generate the n-dimensional discrete Fourier transform sample frequencies.
|
119
130
|
|
120
131
|
Parameters:
|
121
132
|
-----------
|
@@ -133,56 +144,74 @@ def fftfreqn(
|
|
133
144
|
NDArray
|
134
145
|
The sample frequencies.
|
135
146
|
"""
|
136
|
-
|
137
|
-
|
138
|
-
norm =
|
147
|
+
# There is no real need to have these operations on GPU right now
|
148
|
+
temp_backend = NumpyFFTWBackend()
|
149
|
+
norm = temp_backend.full(len(shape), fill_value=1)
|
150
|
+
center = temp_backend.astype(temp_backend.divide(shape, 2), temp_backend._int_dtype)
|
139
151
|
if sampling_rate is not None:
|
140
|
-
norm =
|
152
|
+
norm = temp_backend.astype(temp_backend.multiply(shape, sampling_rate), int)
|
141
153
|
|
142
154
|
if shape_is_real_fourier:
|
143
|
-
center[-1] = 0
|
144
|
-
norm[-1] = 1
|
155
|
+
center[-1], norm[-1] = 0, 1
|
145
156
|
if sampling_rate is not None:
|
146
157
|
norm[-1] = (shape[-1] - 1) * 2 * sampling_rate
|
147
158
|
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
159
|
+
grids = []
|
160
|
+
for i, x in enumerate(shape):
|
161
|
+
baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
|
162
|
+
grid = (temp_backend.arange(x) - center[i]) / norm[i]
|
163
|
+
grids.append(temp_backend.reshape(grid, baseline_dims))
|
152
164
|
|
153
165
|
if compute_euclidean_norm:
|
154
|
-
|
155
|
-
|
156
|
-
|
166
|
+
grids = sum(temp_backend.square(x) for x in grids)
|
167
|
+
grids = temp_backend.sqrt(grids, out=grids)
|
168
|
+
return grids
|
169
|
+
|
170
|
+
if return_sparse_grid:
|
171
|
+
return grids
|
157
172
|
|
158
|
-
|
173
|
+
grid_flesh = temp_backend.full(shape, fill_value=1)
|
174
|
+
grids = temp_backend.stack(tuple(grid * grid_flesh for grid in grids))
|
159
175
|
|
176
|
+
return grids
|
160
177
|
|
161
|
-
|
178
|
+
|
179
|
+
def crop_real_fourier(data: BackendArray) -> BackendArray:
|
162
180
|
"""
|
163
181
|
Crop the real part of a Fourier transform.
|
164
182
|
|
165
183
|
Parameters:
|
166
184
|
-----------
|
167
|
-
data :
|
185
|
+
data : BackendArray
|
168
186
|
The Fourier transformed data.
|
169
187
|
|
170
188
|
Returns:
|
171
189
|
--------
|
172
|
-
|
190
|
+
BackendArray
|
173
191
|
The cropped data.
|
174
192
|
"""
|
175
193
|
stop = 1 + (data.shape[-1] // 2)
|
176
194
|
return data[..., :stop]
|
177
195
|
|
178
196
|
|
179
|
-
def
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
197
|
+
def compute_fourier_shape(
|
198
|
+
shape: Tuple[int], shape_is_real_fourier: bool = False
|
199
|
+
) -> List[int]:
|
200
|
+
if shape_is_real_fourier:
|
201
|
+
return shape
|
202
|
+
shape = [int(x) for x in shape]
|
203
|
+
shape[-1] = 1 + shape[-1] // 2
|
204
|
+
return shape
|
205
|
+
|
206
|
+
|
207
|
+
def shift_fourier(
|
208
|
+
data: BackendArray, shape_is_real_fourier: bool = False
|
209
|
+
) -> BackendArray:
|
210
|
+
shape = be.to_backend_array(data.shape)
|
211
|
+
shift = be.add(be.divide(shape, 2), be.mod(shape, 2))
|
212
|
+
shift = [int(x) for x in shift]
|
184
213
|
if shape_is_real_fourier:
|
185
214
|
shift[-1] = 0
|
186
215
|
|
187
|
-
data =
|
216
|
+
data = be.roll(data, shift, tuple(i for i in range(len(shift))))
|
188
217
|
return data
|
@@ -17,15 +17,18 @@ class ComposableFilter(ABC):
|
|
17
17
|
@abstractmethod
|
18
18
|
def __call__(self, *args, **kwargs) -> Dict:
|
19
19
|
"""
|
20
|
-
|
21
|
-
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
|
22
24
|
*args : tuple
|
23
25
|
Variable length argument list.
|
24
26
|
**kwargs : dict
|
25
27
|
Arbitrary keyword arguments.
|
26
28
|
|
27
|
-
Returns
|
28
|
-
|
29
|
+
Returns
|
30
|
+
-------
|
31
|
+
|
29
32
|
Dict
|
30
33
|
A dictionary representing the result of the filtering operation.
|
31
34
|
"""
|
tme/preprocessing/compose.py
CHANGED
@@ -7,7 +7,7 @@
|
|
7
7
|
|
8
8
|
from typing import Tuple, Dict
|
9
9
|
|
10
|
-
from tme.backends import backend
|
10
|
+
from tme.backends import backend as be
|
11
11
|
|
12
12
|
|
13
13
|
class Compose:
|
@@ -42,9 +42,13 @@ class Compose:
|
|
42
42
|
kwargs.update(meta)
|
43
43
|
ret = transform(**kwargs)
|
44
44
|
|
45
|
+
if "data" not in ret:
|
46
|
+
continue
|
47
|
+
|
45
48
|
if ret.get("is_multiplicative_filter", False):
|
46
|
-
|
47
|
-
ret["
|
49
|
+
prev_data = meta.pop("data")
|
50
|
+
ret["data"] = be.multiply(ret["data"], prev_data, out=ret["data"])
|
51
|
+
ret["merge"], prev_data = None, None
|
48
52
|
|
49
53
|
meta = ret
|
50
54
|
|
@@ -8,12 +8,12 @@ from math import log, sqrt
|
|
8
8
|
from typing import Tuple, Dict
|
9
9
|
|
10
10
|
import numpy as np
|
11
|
-
from numpy.typing import NDArray
|
12
11
|
from scipy.ndimage import mean as ndimean
|
13
12
|
from scipy.ndimage import map_coordinates
|
14
13
|
|
15
|
-
from
|
16
|
-
from ..backends import backend
|
14
|
+
from ..types import BackendArray
|
15
|
+
from ..backends import backend as be
|
16
|
+
from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_shape
|
17
17
|
|
18
18
|
|
19
19
|
class BandPassFilter:
|
@@ -29,7 +29,7 @@ class BandPassFilter:
|
|
29
29
|
highpass : float, optional
|
30
30
|
The highpass cutoff, defaults to None.
|
31
31
|
sampling_rate : Tuple[float], optional
|
32
|
-
The sampling
|
32
|
+
The sampling r_position_to_molmapate in Fourier space, defaults to 1.
|
33
33
|
use_gaussian : bool, optional
|
34
34
|
Whether to use Gaussian bandpass filter, defaults to True.
|
35
35
|
return_real_fourier : bool, optional
|
@@ -63,7 +63,7 @@ class BandPassFilter:
|
|
63
63
|
return_real_fourier: bool = False,
|
64
64
|
shape_is_real_fourier: bool = False,
|
65
65
|
**kwargs,
|
66
|
-
) ->
|
66
|
+
) -> BackendArray:
|
67
67
|
"""
|
68
68
|
Generate a bandpass filter using discrete frequency cutoffs.
|
69
69
|
|
@@ -86,7 +86,7 @@ class BandPassFilter:
|
|
86
86
|
|
87
87
|
Returns:
|
88
88
|
--------
|
89
|
-
|
89
|
+
BackendArray
|
90
90
|
The bandpass filter in Fourier space.
|
91
91
|
"""
|
92
92
|
if shape_is_real_fourier:
|
@@ -127,7 +127,7 @@ class BandPassFilter:
|
|
127
127
|
return_real_fourier: bool = False,
|
128
128
|
shape_is_real_fourier: bool = False,
|
129
129
|
**kwargs,
|
130
|
-
) ->
|
130
|
+
) -> BackendArray:
|
131
131
|
"""
|
132
132
|
Generate a bandpass filter using Gaussian functions.
|
133
133
|
|
@@ -150,7 +150,7 @@ class BandPassFilter:
|
|
150
150
|
|
151
151
|
Returns:
|
152
152
|
--------
|
153
|
-
|
153
|
+
BackendArray
|
154
154
|
The bandpass filter in Fourier space.
|
155
155
|
"""
|
156
156
|
if shape_is_real_fourier:
|
@@ -162,29 +162,32 @@ class BandPassFilter:
|
|
162
162
|
shape_is_real_fourier=shape_is_real_fourier,
|
163
163
|
compute_euclidean_norm=True,
|
164
164
|
)
|
165
|
-
grid =
|
165
|
+
grid = be.to_backend_array(grid)
|
166
|
+
grid = -be.square(grid)
|
166
167
|
|
167
168
|
lowpass_filter, highpass_filter = 1, 1
|
168
169
|
norm = float(sqrt(2 * log(2)))
|
169
|
-
upper_sampling = float(
|
170
|
+
upper_sampling = float(
|
171
|
+
be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
|
172
|
+
)
|
170
173
|
|
171
174
|
if lowpass is not None:
|
172
175
|
lowpass = float(lowpass)
|
173
|
-
lowpass =
|
176
|
+
lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
|
174
177
|
if highpass is not None:
|
175
178
|
highpass = float(highpass)
|
176
|
-
highpass =
|
179
|
+
highpass = be.maximum(highpass, be.eps(be._float_dtype))
|
177
180
|
|
178
181
|
if lowpass is not None:
|
179
182
|
lowpass = upper_sampling / (lowpass * norm)
|
180
|
-
lowpass =
|
181
|
-
lowpass_filter =
|
183
|
+
lowpass = be.multiply(2, be.square(lowpass))
|
184
|
+
lowpass_filter = be.exp(be.divide(grid, lowpass))
|
182
185
|
if highpass is not None:
|
183
186
|
highpass = upper_sampling / (highpass * norm)
|
184
|
-
highpass =
|
185
|
-
highpass_filter = 1 -
|
187
|
+
highpass = be.multiply(2, be.square(highpass))
|
188
|
+
highpass_filter = 1 - be.exp(be.divide(grid, highpass))
|
186
189
|
|
187
|
-
bandpass_filter =
|
190
|
+
bandpass_filter = be.multiply(lowpass_filter, highpass_filter)
|
188
191
|
bandpass_filter = shift_fourier(
|
189
192
|
data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
|
190
193
|
)
|
@@ -205,7 +208,7 @@ class BandPassFilter:
|
|
205
208
|
mask = func(**func_args)
|
206
209
|
|
207
210
|
return {
|
208
|
-
"data":
|
211
|
+
"data": be.to_backend_array(mask),
|
209
212
|
"sampling_rate": func_args.get("sampling_rate", 1),
|
210
213
|
"is_multiplicative_filter": True,
|
211
214
|
}
|
@@ -237,14 +240,14 @@ class LinearWhiteningFilter:
|
|
237
240
|
|
238
241
|
@staticmethod
|
239
242
|
def _compute_spectrum(
|
240
|
-
data_rfft:
|
241
|
-
) -> Tuple[
|
243
|
+
data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
|
244
|
+
) -> Tuple[BackendArray, BackendArray]:
|
242
245
|
"""
|
243
246
|
Compute the spectrum of the input data.
|
244
247
|
|
245
248
|
Parameters:
|
246
249
|
-----------
|
247
|
-
data_rfft :
|
250
|
+
data_rfft : BackendArray
|
248
251
|
The Fourier transform of the input data.
|
249
252
|
n_bins : int, optional
|
250
253
|
The number of bins for computing the spectrum, defaults to None.
|
@@ -253,9 +256,9 @@ class LinearWhiteningFilter:
|
|
253
256
|
|
254
257
|
Returns:
|
255
258
|
--------
|
256
|
-
bins :
|
259
|
+
bins : BackendArray
|
257
260
|
Array containing the bin indices for the spectrum.
|
258
|
-
radial_averages :
|
261
|
+
radial_averages : BackendArray
|
259
262
|
Array containing the radial averages of the spectrum.
|
260
263
|
"""
|
261
264
|
shape = tuple(x for i, x in enumerate(data_rfft.shape) if i != batch_dimension)
|
@@ -270,7 +273,7 @@ class LinearWhiteningFilter:
|
|
270
273
|
shape_is_real_fourier=True,
|
271
274
|
compute_euclidean_norm=True,
|
272
275
|
)
|
273
|
-
bins =
|
276
|
+
bins = be.to_numpy_array(bins)
|
274
277
|
|
275
278
|
# Implicit lowpass to nyquist
|
276
279
|
bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)
|
@@ -292,11 +295,11 @@ class LinearWhiteningFilter:
|
|
292
295
|
|
293
296
|
@staticmethod
|
294
297
|
def _interpolate_spectrum(
|
295
|
-
spectrum:
|
298
|
+
spectrum: BackendArray,
|
296
299
|
shape: Tuple[int],
|
297
300
|
shape_is_real_fourier: bool = True,
|
298
301
|
order: int = 1,
|
299
|
-
) ->
|
302
|
+
) -> BackendArray:
|
300
303
|
"""
|
301
304
|
References
|
302
305
|
----------
|
@@ -306,19 +309,19 @@ class LinearWhiteningFilter:
|
|
306
309
|
"""
|
307
310
|
grid = fftfreqn(
|
308
311
|
shape=shape,
|
309
|
-
sampling_rate
|
312
|
+
sampling_rate=0.5,
|
310
313
|
shape_is_real_fourier=shape_is_real_fourier,
|
311
314
|
compute_euclidean_norm=True,
|
312
315
|
)
|
313
|
-
grid =
|
314
|
-
np.multiply(grid, (spectrum.shape[0] - 1), out
|
316
|
+
grid = be.to_numpy_array(grid)
|
317
|
+
np.multiply(grid, (spectrum.shape[0] - 1), out=grid) + 0.5
|
315
318
|
spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
|
316
319
|
return spectrum.reshape(grid.shape)
|
317
320
|
|
318
321
|
def __call__(
|
319
322
|
self,
|
320
|
-
data:
|
321
|
-
data_rfft:
|
323
|
+
data: BackendArray = None,
|
324
|
+
data_rfft: BackendArray = None,
|
322
325
|
n_bins: int = None,
|
323
326
|
batch_dimension: int = None,
|
324
327
|
order: int = 1,
|
@@ -329,9 +332,9 @@ class LinearWhiteningFilter:
|
|
329
332
|
|
330
333
|
Parameters:
|
331
334
|
-----------
|
332
|
-
data :
|
335
|
+
data : BackendArray, optional
|
333
336
|
The input data, defaults to None.
|
334
|
-
data_rfft :
|
337
|
+
data_rfft : BackendArray, optional
|
335
338
|
The Fourier transform of the input data, defaults to None.
|
336
339
|
n_bins : int, optional
|
337
340
|
The number of bins for computing the spectrum, defaults to None.
|
@@ -348,9 +351,9 @@ class LinearWhiteningFilter:
|
|
348
351
|
Filter data and associated parameters.
|
349
352
|
"""
|
350
353
|
if data_rfft is None:
|
351
|
-
data_rfft = np.fft.rfftn(
|
354
|
+
data_rfft = np.fft.rfftn(be.to_numpy_array(data))
|
352
355
|
|
353
|
-
data_rfft =
|
356
|
+
data_rfft = be.to_numpy_array(data_rfft)
|
354
357
|
|
355
358
|
bins, radial_averages = self._compute_spectrum(
|
356
359
|
data_rfft, n_bins, batch_dimension
|
@@ -358,21 +361,28 @@ class LinearWhiteningFilter:
|
|
358
361
|
|
359
362
|
if order is None:
|
360
363
|
cutoff = bins < radial_averages.size
|
361
|
-
filter_mask = np.zeros(
|
364
|
+
filter_mask = np.zeros(bins.shape, radial_averages.dtype)
|
362
365
|
filter_mask[cutoff] = radial_averages[bins[cutoff]]
|
363
366
|
else:
|
367
|
+
shape = bins.shape
|
368
|
+
if kwargs.get("shape", False):
|
369
|
+
shape = compute_fourier_shape(
|
370
|
+
shape=kwargs.get("shape"),
|
371
|
+
shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
|
372
|
+
)
|
373
|
+
|
364
374
|
filter_mask = self._interpolate_spectrum(
|
365
375
|
spectrum=radial_averages,
|
366
|
-
shape=
|
376
|
+
shape=shape,
|
367
377
|
shape_is_real_fourier=True,
|
368
378
|
)
|
369
379
|
|
370
380
|
filter_mask = np.fft.ifftshift(
|
371
381
|
filter_mask,
|
372
|
-
axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension)
|
382
|
+
axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension),
|
373
383
|
)
|
374
384
|
|
375
385
|
return {
|
376
|
-
"data":
|
386
|
+
"data": be.to_backend_array(filter_mask),
|
377
387
|
"is_multiplicative_filter": True,
|
378
388
|
}
|