pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,282 @@
1
+ """ Backend using jax for template matching.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from functools import wraps
8
+ from typing import Tuple, List, Callable
9
+
10
+ from ..types import BackendArray
11
+ from .npfftw_backend import NumpyFFTWBackend, shm_type
12
+
13
+
14
def emulate_out(func):
    """
    Wrap ``func`` so that it accepts an optional ``out`` keyword argument.

    Without ``out`` the wrapper returns whatever ``func`` returns. With
    ``out`` the result is written via the functional update
    ``out.at[:].set(...)`` and that updated array is returned; ``out`` itself
    is not modified in place.
    """

    @wraps(func)
    def wrapper(*args, out=None, **kwargs):
        result = func(*args, **kwargs)
        if out is None:
            return result
        return out.at[:].set(result)

    return wrapper
28
+
29
+
30
class JaxBackend(NumpyFFTWBackend):
    """
    A jax-based matching backend.

    Arrays are updated functionally (``arr.at[...].set(...)``), so methods
    accepting ``out`` arguments return their results instead of modifying
    ``out`` in place.
    """

    def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
        """
        Parameters
        ----------
        float_dtype : dtype, optional
            Floating point dtype, defaults to ``jax.numpy.float32``. Also used
            as the overflow-safe dtype.
        complex_dtype : dtype, optional
            Complex dtype, defaults to ``jax.numpy.complex64``.
        int_dtype : dtype, optional
            Integer dtype, defaults to ``jax.numpy.int32``.
        **kwargs
            Accepted for signature compatibility; currently unused.
        """
        import jax.scipy as jsp
        import jax.numpy as jnp

        float_dtype = jnp.float32 if float_dtype is None else float_dtype
        complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = jnp.int32 if int_dtype is None else int_dtype

        super().__init__(
            array_backend=jnp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=float_dtype,
        )
        self.scipy = jsp
        self._create_ufuncs()

        # Deliberate best effort: only expose the batched jax scan when its
        # helper module imports cleanly; otherwise leave ``scan`` untouched.
        try:
            from ._jax_utils import scan as _  # noqa: F401
        except Exception:
            pass
        else:
            self.scan = self._scan

    def from_sharedarr(self, arr: BackendArray) -> BackendArray:
        """Identity: this backend does not place arrays in shared memory."""
        return arr

    @staticmethod
    def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
        """Identity: ``shared_memory_handler`` is ignored and ``arr`` returned."""
        return arr

    def topleft_pad(
        self, arr: BackendArray, shape: Tuple[int], padval: int = 0
    ) -> BackendArray:
        """
        Pad or crop ``arr`` to ``shape``, anchored at the top-left corner.

        Parameters
        ----------
        arr : BackendArray
            Input array.
        shape : tuple of int
            Output shape.
        padval : int, optional
            Value used for padded regions, defaults to 0.

        Returns
        -------
        BackendArray
            New array of ``shape`` whose leading corner holds ``arr``; axes on
            which ``arr`` is larger than ``shape`` are cropped.
        """
        b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
        # Region shared by input and output along each axis.
        overlap = tuple(slice(0, min(a, s)) for a, s in zip(arr.shape, shape))
        return b.at[overlap].set(arr[overlap])

    def _create_ufuncs(self):
        """
        Expose selected ``jax.numpy`` functions as attributes of this instance.

        Element-wise operations are wrapped with :py:func:`emulate_out` so they
        accept an ``out`` keyword; array constructors are exposed unchanged.
        Plain callables are stored: instance attributes are never bound, so a
        ``staticmethod`` wrapper is unnecessary (and not callable before
        Python 3.10).
        """
        elementwise = (
            "add",
            "subtract",
            "multiply",
            "divide",
            "square",
            "sqrt",
            "maximum",
        )
        for name in elementwise:
            setattr(self, name, emulate_out(getattr(self._array_backend, name)))

        for name in ("zeros", "full"):
            setattr(self, name, getattr(self._array_backend, name))

    def fill(self, arr: BackendArray, value: float) -> BackendArray:
        """
        Return a new array with the shape and dtype of ``arr`` filled with
        ``value``. ``arr`` itself is not modified.
        """
        return self._array_backend.full(
            shape=arr.shape, dtype=arr.dtype, fill_value=value
        )

    def build_fft(
        self,
        fast_shape: Tuple[int],
        fast_ft_shape: Tuple[int],
        inverse_fast_shape: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Create forward and inverse real-valued n-dimensional FFT callables.

        Parameters
        ----------
        fast_shape : tuple of int
            Real-space shape passed as ``s`` to the forward transform.
        fast_ft_shape : tuple of int
            Fourier-space shape; unused by this backend.
        inverse_fast_shape : tuple of int, optional
            Real-space output shape of the inverse transform, defaults to
            ``fast_shape``.
        **kwargs
            Unused.

        Returns
        -------
        tuple of callables
            ``(rfftn, irfftn)``. Both accept ``(arr, out)`` for interface
            compatibility, ignore ``out`` and return the transformed array.
        """
        if inverse_fast_shape is None:
            inverse_fast_shape = fast_shape

        def rfftn(arr, out=None, shape=fast_shape):
            return self._array_backend.fft.rfftn(arr, s=shape)

        # The inverse honours ``inverse_fast_shape`` so callers can request an
        # output shape different from the forward transform's input shape.
        def irfftn(arr, out=None, shape=inverse_fast_shape):
            return self._array_backend.fft.irfftn(arr, s=shape)

        return rfftn, irfftn

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Compute convolution shapes with an even last axis of the padded shape.

        Delegates to the parent implementation. When the last axis of
        ``fast_shape`` is odd it is increased by one, and the last axis of
        ``fast_ft_shape`` is increased by one as well (an rfft of even length
        ``n + 1`` has one more coefficient than one of odd length ``n``).
        The shape sequences returned by the parent are modified in place.
        """
        conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
            arr1_shape, arr2_shape
        )

        is_odd = fast_shape[-1] % 2
        fast_shape[-1] += is_odd
        fast_ft_shape[-1] += is_odd

        return conv_shape, fast_shape, fast_ft_shape

    def rigid_transform(
        self,
        arr: BackendArray,
        rotation_matrix: BackendArray,
        out: BackendArray = None,
        out_mask: BackendArray = None,
        translation: BackendArray = None,
        arr_mask: BackendArray = None,
        order: int = 1,
        **kwargs,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Rotate (and optionally translate) ``arr`` about its center.

        Sampling coordinates are ``R.T @ (x - c) + c + t``, where ``c`` is
        half the array shape and ``t`` the optional translation; values are
        interpolated with ``jax.scipy.ndimage.map_coordinates``.

        Parameters
        ----------
        arr : BackendArray
            Array to transform.
        rotation_matrix : BackendArray
            Square matrix of size ``arr.ndim``.
        out, out_mask : BackendArray, optional
            Ignored; results are returned as new arrays.
        translation : BackendArray, optional
            One offset per axis, added to the sampling coordinates.
        arr_mask : BackendArray, optional
            Mask transformed with the same coordinates when given.
        order : int, optional
            Interpolation order passed to ``map_coordinates``.

        Returns
        -------
        tuple
            Transformed array and transformed mask. The mask is ``None`` when
            ``arr_mask`` is ``None``.
        """
        rotate_mask = arr_mask is not None
        center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]

        indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
        indices = indices.reshape((arr.ndim, -1))
        indices = indices.at[:].add(-center)
        indices = self._array_backend.matmul(rotation_matrix.T, indices)
        indices = indices.at[:].add(center)
        if translation is not None:
            # Column vector so each axis offset broadcasts across all points,
            # mirroring the treatment of ``center`` above.
            translation = self._array_backend.asarray(translation).reshape(-1, 1)
            indices = indices.at[:].add(translation)

        out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
            arr.shape
        )

        out_mask = arr_mask
        if rotate_mask:
            out_mask = self.scipy.ndimage.map_coordinates(
                arr_mask, indices, order=order
            ).reshape(arr_mask.shape)

        return out, out_mask

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Update running maxima with ``scores`` element-wise.

        Wherever ``max_scores`` is not strictly greater than ``scores``, the
        score is taken from ``scores`` and the rotation is set to
        ``rotation_index``. Dtypes of ``max_scores`` and ``rotations`` are
        preserved.

        Returns
        -------
        tuple
            Updated ``(max_scores, rotations)``.
        """
        update = self.greater(max_scores, scores)
        max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
        rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
        return max_scores, rotations

    def _scan(
        self,
        matching_data: type,
        splits: Tuple[Tuple[slice, slice]],
        n_jobs: int,
        callback_class,
        rotate_mask: bool = False,
        **kwargs,
    ) -> List:
        """
        Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
        :py:class:`tme.analyzer.MaxScoreOverRotations`.

        Splits are processed in batches of ``n_jobs``. The padded target
        splits of a batch are stacked and passed to the jax scan kernel
        together.

        Parameters
        ----------
        matching_data : type
            Matching data object providing target, template, template mask,
            rotations and optional template/target filters.
        splits : tuple of (slice, slice) pairs
            Target and template slices for each split.
        n_jobs : int
            Number of splits processed per batch.
        callback_class : type
            Analyzer class instantiated per split to postprocess its scores.
        rotate_mask : bool, optional
            Forwarded to the scan kernel.
        **kwargs
            Unused.

        Returns
        -------
        list
            One tuple of postprocessed analyzer output per split.
        """
        from ._jax_utils import scan as scan_inner

        # With more than one split the target is padded and the analyzer is
        # told to use the "valid" convolution mode.
        pad_target = len(splits) > 1
        convolution_mode = "valid" if pad_target else "same"
        target_pad = matching_data.target_padding(pad_target=pad_target)

        # Shape of the first padded target split; it sizes the Fourier grid
        # that every split is padded to below.
        target_shape = tuple(
            (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
        )
        fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
            target_shape=self.to_numpy_array(target_shape),
            template_shape=self.to_numpy_array(matching_data._template.shape),
            pad_fourier=False,
        )

        analyzer_args = {
            "convolution_mode": convolution_mode,
            "fourier_shift": shift,
            "targetshape": target_shape,
            "templateshape": matching_data._template.shape,
        }

        create_target_filter = matching_data.target_filter is not None
        create_template_filter = matching_data.template_filter is not None
        create_filter = create_target_filter or create_template_filter

        ret, template_filter, target_filter = [], 1, 1
        rotation_mapping = {
            self.tobytes(matching_data.rotations[i]): i
            for i in range(matching_data.rotations.shape[0])
        }
        for split_start in range(0, len(splits), n_jobs):
            split_subset = splits[split_start : (split_start + n_jobs)]

            targets, translation_offsets = [], []
            for target_split, template_split in split_subset:
                base = matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                )
                translation_offsets.append(base._translation_offset)
                targets.append(self.topleft_pad(base._target, fast_shape))

            # Filters are built once, from the first target of the first batch,
            # and reused for all later batches.
            if create_filter:
                filter_args = {
                    "data_rfft": self.fft.rfftn(targets[0]),
                    "return_real_fourier": True,
                    "shape_is_real_fourier": False,
                }

                # The coefficient at the origin is zeroed in both filters.
                if create_template_filter:
                    template_filter = matching_data.template_filter(
                        shape=matching_data._template.shape, **filter_args
                    )["data"]
                    template_filter = template_filter.at[
                        (0,) * template_filter.ndim
                    ].set(0)

                if create_target_filter:
                    target_filter = matching_data.target_filter(
                        shape=fast_shape, **filter_args
                    )["data"]
                    target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)

                create_filter = False

            # Drop the reference to the last subset before the stacked batch
            # is handed to the kernel.
            base, targets = None, self._array_backend.stack(targets)
            scores, rotations = scan_inner(
                targets,
                matching_data.template,
                matching_data.template_mask,
                matching_data.rotations,
                template_filter,
                target_filter,
                fast_shape,
                rotate_mask,
            )

            for index in range(scores.shape[0]):
                temp = callback_class(
                    scores=scores[index],
                    rotations=rotations[index],
                    thread_safe=False,
                    offset=translation_offsets[index],
                )
                temp.rotation_mapping = rotation_mapping
                ret.append(tuple(temp._postprocess(**analyzer_args)))

        return ret