pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/backends/jax_backend.py
CHANGED
@@ -11,7 +11,7 @@ from typing import Tuple, List, Dict, Any
|
|
11
11
|
|
12
12
|
import numpy as np
|
13
13
|
|
14
|
-
from ..types import
|
14
|
+
from ..types import JaxArray
|
15
15
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
16
16
|
|
17
17
|
|
@@ -54,25 +54,23 @@ class JaxBackend(NumpyFFTWBackend):
|
|
54
54
|
self.scipy = jsp
|
55
55
|
self._create_ufuncs()
|
56
56
|
|
57
|
-
def from_sharedarr(self, arr:
|
57
|
+
def from_sharedarr(self, arr: JaxArray) -> JaxArray:
|
58
58
|
return arr
|
59
59
|
|
60
60
|
@staticmethod
|
61
|
-
def to_sharedarr(arr:
|
61
|
+
def to_sharedarr(arr: JaxArray, shared_memory_handler: type = None) -> shm_type:
|
62
62
|
return arr
|
63
63
|
|
64
64
|
@staticmethod
|
65
|
-
def at(arr, idx, value) ->
|
66
|
-
|
67
|
-
return arr
|
65
|
+
def at(arr, idx, value) -> JaxArray:
|
66
|
+
return arr.at[idx].set(value)
|
68
67
|
|
69
68
|
def addat(self, arr, indices, values):
|
70
|
-
|
71
|
-
return arr
|
69
|
+
return arr.at[indices].add(values)
|
72
70
|
|
73
71
|
def topleft_pad(
|
74
|
-
self, arr:
|
75
|
-
) ->
|
72
|
+
self, arr: JaxArray, shape: Tuple[int], padval: int = 0
|
73
|
+
) -> JaxArray:
|
76
74
|
b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
|
77
75
|
aind = [slice(None, None)] * arr.ndim
|
78
76
|
bind = [slice(None, None)] * arr.ndim
|
@@ -95,6 +93,7 @@ class JaxBackend(NumpyFFTWBackend):
|
|
95
93
|
"maximum",
|
96
94
|
"exp",
|
97
95
|
"mod",
|
96
|
+
"dot",
|
98
97
|
]
|
99
98
|
for ufunc in ufuncs:
|
100
99
|
backend_method = emulate_out(getattr(self._array_backend, ufunc))
|
@@ -105,65 +104,68 @@ class JaxBackend(NumpyFFTWBackend):
|
|
105
104
|
backend_method = getattr(self._array_backend, ufunc)
|
106
105
|
setattr(self, ufunc, staticmethod(backend_method))
|
107
106
|
|
108
|
-
def fill(self, arr:
|
107
|
+
def fill(self, arr: JaxArray, value: float) -> JaxArray:
|
109
108
|
return self._array_backend.full(
|
110
109
|
shape=arr.shape, dtype=arr.dtype, fill_value=value
|
111
110
|
)
|
112
111
|
|
113
|
-
def rfftn(self, arr:
|
112
|
+
def rfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
|
114
113
|
return self._array_backend.fft.rfftn(arr, **kwargs)
|
115
114
|
|
116
|
-
def irfftn(self, arr:
|
115
|
+
def irfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
|
117
116
|
return self._array_backend.fft.irfftn(arr, **kwargs)
|
118
117
|
|
119
|
-
def
|
118
|
+
def _interpolate(self, arr, indices, order: int = 1):
|
119
|
+
ret = self.scipy.ndimage.map_coordinates(arr, indices, order=order)
|
120
|
+
return ret.reshape(arr.shape)
|
121
|
+
|
122
|
+
def _index_grid(self, shape: Tuple[int]) -> JaxArray:
|
123
|
+
"""
|
124
|
+
Create homogeneous coordinate grid.
|
125
|
+
|
126
|
+
Parameters
|
127
|
+
----------
|
128
|
+
shape : tuple of int
|
129
|
+
Shape to create the grid for
|
130
|
+
|
131
|
+
Returns
|
132
|
+
-------
|
133
|
+
JaxArray
|
134
|
+
Coordinate grid of shape (ndim + int(homogeneous), n_points)
|
135
|
+
"""
|
136
|
+
indices = self._array_backend.indices(shape, dtype=self._float_dtype)
|
137
|
+
indices = indices.reshape((len(shape), -1))
|
138
|
+
ones = self._array_backend.ones((1, indices.shape[1]), dtype=indices.dtype)
|
139
|
+
return self._array_backend.concatenate([indices, ones], axis=0)
|
140
|
+
|
141
|
+
def _transform_indices(self, indices: JaxArray, matrix: JaxArray) -> JaxArray:
|
142
|
+
return self._array_backend.matmul(matrix[:-1], indices)
|
143
|
+
|
144
|
+
def _rigid_transform(
|
120
145
|
self,
|
121
|
-
arr:
|
122
|
-
|
123
|
-
out:
|
124
|
-
out_mask:
|
125
|
-
|
126
|
-
arr_mask: BackendArray = None,
|
146
|
+
arr: JaxArray,
|
147
|
+
matrix: JaxArray,
|
148
|
+
out: JaxArray = None,
|
149
|
+
out_mask: JaxArray = None,
|
150
|
+
arr_mask: JaxArray = None,
|
127
151
|
order: int = 1,
|
128
152
|
**kwargs,
|
129
|
-
) -> Tuple[
|
130
|
-
|
131
|
-
|
132
|
-
# This approach is only valid for order <= 1
|
133
|
-
if arr.ndim != rotation_matrix.shape[0]:
|
134
|
-
matrix = self._array_backend.zeros((arr.ndim, arr.ndim))
|
135
|
-
matrix = matrix.at[0, 0].set(1)
|
136
|
-
matrix = matrix.at[1:, 1:].add(rotation_matrix)
|
137
|
-
rotation_matrix = matrix
|
138
|
-
|
139
|
-
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
|
140
|
-
indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
|
141
|
-
indices = indices.reshape((arr.ndim, -1))
|
142
|
-
indices = indices.at[:].add(-center)
|
143
|
-
indices = self._array_backend.matmul(rotation_matrix.T, indices)
|
144
|
-
indices = indices.at[:].add(center)
|
145
|
-
if translation is not None:
|
146
|
-
indices = indices.at[:].add(translation)
|
147
|
-
|
148
|
-
out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
|
149
|
-
arr.shape
|
150
|
-
)
|
153
|
+
) -> Tuple[JaxArray, JaxArray]:
|
154
|
+
indices = self._index_grid(arr.shape)
|
155
|
+
indices = self._transform_indices(indices, matrix)
|
151
156
|
|
152
|
-
|
153
|
-
if
|
154
|
-
|
155
|
-
|
156
|
-
).reshape(arr_mask.shape)
|
157
|
-
|
158
|
-
return out, out_mask
|
157
|
+
arr = self._interpolate(arr, indices, order)
|
158
|
+
if arr_mask is not None:
|
159
|
+
arr_mask = self._interpolate(out_mask, indices, order)
|
160
|
+
return arr, arr_mask
|
159
161
|
|
160
162
|
def max_score_over_rotations(
|
161
163
|
self,
|
162
|
-
scores:
|
163
|
-
max_scores:
|
164
|
-
rotations:
|
164
|
+
scores: JaxArray,
|
165
|
+
max_scores: JaxArray,
|
166
|
+
rotations: JaxArray,
|
165
167
|
rotation_index: int,
|
166
|
-
) -> Tuple[
|
168
|
+
) -> Tuple[JaxArray, JaxArray]:
|
167
169
|
update = self.greater(max_scores, scores)
|
168
170
|
max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
|
169
171
|
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
|
@@ -212,64 +214,69 @@ class JaxBackend(NumpyFFTWBackend):
|
|
212
214
|
callback_class: object,
|
213
215
|
callback_class_args: Dict,
|
214
216
|
rotate_mask: bool = False,
|
217
|
+
background_correction: str = None,
|
218
|
+
match_projection: bool = False,
|
215
219
|
**kwargs,
|
216
220
|
) -> List:
|
217
221
|
"""
|
218
|
-
Emulates output of :py:meth:`tme.matching_exhaustive.
|
219
|
-
:py:class:`tme.analyzer.MaxScoreOverRotations`.
|
222
|
+
Emulates output of :py:meth:`tme.matching_exhaustive._match_exhaustive`.
|
220
223
|
"""
|
221
224
|
from ._jax_utils import setup_scan
|
225
|
+
from ..matching_utils import setup_filter
|
222
226
|
from ..analyzer import MaxScoreOverRotations
|
223
227
|
|
224
228
|
pad_target = True if len(splits) > 1 else False
|
225
|
-
convolution_mode = "valid" if pad_target else "same"
|
226
229
|
target_pad = matching_data.target_padding(pad_target=pad_target)
|
230
|
+
template_shape = matching_data._batch_shape(
|
231
|
+
matching_data.template.shape, matching_data._target_batch
|
232
|
+
)
|
227
233
|
|
228
|
-
score_mask =
|
234
|
+
score_mask = 1
|
229
235
|
target_shape = tuple(
|
230
236
|
(x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
|
231
237
|
)
|
232
|
-
conv_shape, fast_shape, fast_ft_shape, shift = matching_data.
|
233
|
-
target_shape=
|
234
|
-
template_shape=self.to_numpy_array(matching_data._template.shape),
|
235
|
-
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
238
|
+
conv_shape, fast_shape, fast_ft_shape, shift = matching_data.fourier_padding(
|
239
|
+
target_shape=target_shape
|
236
240
|
)
|
241
|
+
|
237
242
|
analyzer_args = {
|
238
243
|
"shape": fast_shape,
|
239
244
|
"fourier_shift": shift,
|
240
245
|
"fast_shape": fast_shape,
|
241
|
-
"
|
242
|
-
"templateshape": matching_data.template.shape,
|
246
|
+
"templateshape": template_shape,
|
243
247
|
"convolution_shape": conv_shape,
|
244
|
-
"convolution_mode":
|
248
|
+
"convolution_mode": "valid" if pad_target else "same",
|
245
249
|
"thread_safe": False,
|
246
250
|
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
247
251
|
"n_rotations": matching_data.rotations.shape[0],
|
248
252
|
"jax_mode": True,
|
249
253
|
}
|
254
|
+
analyzer_args.update(callback_class_args)
|
255
|
+
|
250
256
|
create_target_filter = matching_data.target_filter is not None
|
251
257
|
create_template_filter = matching_data.template_filter is not None
|
252
258
|
create_filter = create_target_filter or create_template_filter
|
253
259
|
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
260
|
+
bg_tmpl = 1
|
261
|
+
if background_correction == "phase-scrambling":
|
262
|
+
bg_tmpl = matching_data.transform_template(
|
263
|
+
"phase_randomization", reverse=True
|
264
|
+
)
|
265
|
+
bg_tmpl = self.astype(bg_tmpl, self._float_dtype)
|
258
266
|
|
267
|
+
rotations = self.astype(matching_data.rotations, self._float_dtype)
|
259
268
|
ret, template_filter, target_filter = [], 1, 1
|
260
269
|
rotation_mapping = {
|
261
|
-
self.tobytes(
|
262
|
-
for i in range(matching_data.rotations.shape[0])
|
270
|
+
self.tobytes(rotations[i]): i for i in range(rotations.shape[0])
|
263
271
|
}
|
264
272
|
for split_start in range(0, len(splits), n_jobs):
|
265
273
|
|
266
274
|
analyzer_kwargs = []
|
267
|
-
|
268
275
|
split_subset = splits[split_start : (split_start + n_jobs)]
|
269
276
|
if not len(split_subset):
|
270
277
|
continue
|
271
278
|
|
272
|
-
targets
|
279
|
+
targets = []
|
273
280
|
for target_split, template_split in split_subset:
|
274
281
|
base, translation_offset = matching_data.subset_by_slice(
|
275
282
|
target_slice=target_split,
|
@@ -278,52 +285,50 @@ class JaxBackend(NumpyFFTWBackend):
|
|
278
285
|
)
|
279
286
|
cur_args = analyzer_args.copy()
|
280
287
|
cur_args["offset"] = translation_offset
|
281
|
-
cur_args.
|
288
|
+
cur_args["targetshape"] = base._output_shape
|
282
289
|
analyzer_kwargs.append(cur_args)
|
283
290
|
|
284
291
|
if pad_target:
|
285
292
|
score_mask = base._score_mask(fast_shape, shift)
|
286
293
|
|
287
|
-
|
288
|
-
|
289
|
-
targets.append(self.topleft_pad(_target, fast_shape))
|
294
|
+
# We prepad outside of jit to guarantee the stack operation works
|
295
|
+
targets.append(self.topleft_pad(base._target, fast_shape))
|
290
296
|
|
291
297
|
if create_filter:
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
)["data"]
|
301
|
-
template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
|
302
|
-
|
303
|
-
if create_target_filter:
|
304
|
-
target_filter = matching_data.target_filter(
|
305
|
-
shape=fast_shape, **filter_args
|
306
|
-
)["data"]
|
307
|
-
target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
|
308
|
-
|
309
|
-
create_filter, create_template_filter, create_target_filter = (False,) * 3
|
310
|
-
base, targets = None, self._array_backend.stack(targets)
|
298
|
+
# This is technically inaccurate for whitening filters
|
299
|
+
template_filter, target_filter = setup_filter(
|
300
|
+
matching_data=base,
|
301
|
+
fast_shape=fast_shape,
|
302
|
+
fast_ft_shape=fast_ft_shape,
|
303
|
+
pad_template_filter=False,
|
304
|
+
apply_target_filter=False,
|
305
|
+
)
|
311
306
|
|
307
|
+
# For projection matching we allow broadcasting the first dimension
|
308
|
+
# This becomes problematic when applying per-tilt filters to the target
|
309
|
+
# as the number of tilts does not necessarily coincide with the ideal
|
310
|
+
# fourier shape. Hence we pad the target_filter with zeros here
|
311
|
+
if target_filter.shape != (1,):
|
312
|
+
target_filter = self.topleft_pad(target_filter, fast_ft_shape)
|
313
|
+
|
314
|
+
base, targets = None, self._array_backend.stack(targets)
|
312
315
|
scan_inner = setup_scan(
|
313
316
|
analyzer_kwargs=analyzer_kwargs,
|
314
|
-
|
317
|
+
analyzer=callback_class,
|
315
318
|
fast_shape=fast_shape,
|
316
|
-
rotate_mask=rotate_mask
|
319
|
+
rotate_mask=rotate_mask,
|
320
|
+
match_projection=match_projection,
|
317
321
|
)
|
318
322
|
|
319
323
|
states = scan_inner(
|
320
324
|
self.astype(targets, self._float_dtype),
|
321
325
|
self.astype(matching_data.template, self._float_dtype),
|
322
326
|
self.astype(matching_data.template_mask, self._float_dtype),
|
323
|
-
|
327
|
+
rotations,
|
324
328
|
template_filter,
|
325
329
|
target_filter,
|
326
330
|
score_mask,
|
331
|
+
bg_tmpl,
|
327
332
|
)
|
328
333
|
|
329
334
|
ndim = targets.ndim - 1
|
tme/backends/matching_backend.py
CHANGED
@@ -1105,23 +1105,6 @@ class MatchingBackend(ABC):
|
|
1105
1105
|
def irfftn(self, **kwargs):
|
1106
1106
|
"""Perform an n-D real inverse FFT."""
|
1107
1107
|
|
1108
|
-
def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
|
1109
|
-
"""
|
1110
|
-
Extract the centered portion of an array based on a new shape.
|
1111
|
-
|
1112
|
-
Parameters
|
1113
|
-
----------
|
1114
|
-
arr : BackendArray
|
1115
|
-
Input data.
|
1116
|
-
newshape : tuple
|
1117
|
-
Desired shape for the central portion.
|
1118
|
-
|
1119
|
-
Returns
|
1120
|
-
-------
|
1121
|
-
BackendArray
|
1122
|
-
Central portion of the array with shape ``newshape``.
|
1123
|
-
"""
|
1124
|
-
|
1125
1108
|
@abstractmethod
|
1126
1109
|
def compute_convolution_shapes(
|
1127
1110
|
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
tme/backends/mlx_backend.py
CHANGED
@@ -115,35 +115,6 @@ class MLXBackend(NumpyFFTWBackend):
|
|
115
115
|
)
|
116
116
|
return self.to_backend_array(ret)
|
117
117
|
|
118
|
-
def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
119
|
-
"""
|
120
|
-
Extract the centered portion of an array based on a new shape.
|
121
|
-
|
122
|
-
Parameters
|
123
|
-
----------
|
124
|
-
arr : NDArray
|
125
|
-
Input array.
|
126
|
-
newshape : tuple
|
127
|
-
Desired shape for the central portion.
|
128
|
-
|
129
|
-
Returns
|
130
|
-
-------
|
131
|
-
NDArray
|
132
|
-
Central portion of the array with shape `newshape`.
|
133
|
-
|
134
|
-
References
|
135
|
-
----------
|
136
|
-
.. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
|
137
|
-
"""
|
138
|
-
new_shape = self.to_backend_array(newshape)
|
139
|
-
current_shape = self.to_backend_array(arr.shape)
|
140
|
-
starts = self.subtract(current_shape, new_shape)
|
141
|
-
starts = self.astype(self.divide(starts, 2), self._int_dtype)
|
142
|
-
stops = self.astype(self.add(starts, newshape), self._int_dtype)
|
143
|
-
starts, stops = starts.tolist(), stops.tolist()
|
144
|
-
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
145
|
-
return arr[box]
|
146
|
-
|
147
118
|
def rfftn(self, arr, *args, **kwargs):
|
148
119
|
return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
|
149
120
|
|