pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (49):
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,38 @@
1
1
  """ Backend using jax for template matching.
2
2
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- from typing import Tuple, Callable
7
+ from functools import wraps
8
+ from typing import Tuple, List, Callable
9
+
10
+ from ..types import BackendArray
11
+ from .npfftw_backend import NumpyFFTWBackend, shm_type
12
+
13
+
14
+ def emulate_out(func):
15
+ """
16
+ Adds an out argument to write output of ``func`` to.
17
+ """
18
+
19
+ @wraps(func)
20
+ def inner(*args, out=None, **kwargs):
21
+ ret = func(*args, **kwargs)
22
+ if out is not None:
23
+ out = out.at[:].set(ret)
24
+ return out
25
+ return ret
26
+
27
+ return inner
8
28
 
9
- from .npfftw_backend import NumpyFFTWBackend
10
29
 
11
30
  class JaxBackend(NumpyFFTWBackend):
12
- def __init__(
13
- self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs
14
- ):
31
+ """
32
+ A jax-based matching backend.
33
+ """
34
+
35
+ def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
15
36
  import jax.scipy as jsp
16
37
  import jax.numpy as jnp
17
38
 
@@ -19,24 +40,33 @@ class JaxBackend(NumpyFFTWBackend):
19
40
  complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
20
41
  int_dtype = jnp.int32 if int_dtype is None else int_dtype
21
42
 
22
- self.scipy = jsp
23
43
  super().__init__(
24
44
  array_backend=jnp,
25
45
  float_dtype=float_dtype,
26
46
  complex_dtype=complex_dtype,
27
47
  int_dtype=int_dtype,
48
+ overflow_safe_dtype=float_dtype,
28
49
  )
50
+ self.scipy = jsp
51
+ self._create_ufuncs()
52
+ try:
53
+ from ._jax_utils import scan as _
29
54
 
30
- def to_backend_array(self, arr):
31
- return self._array_backend.asarray(arr)
55
+ self.scan = self._scan
56
+ except Exception:
57
+ pass
32
58
 
33
- def preallocate_array(self, shape: Tuple[int], dtype: type):
34
- arr = self._array_backend.zeros(shape, dtype=dtype)
59
+ def from_sharedarr(self, arr: BackendArray) -> BackendArray:
60
+ return arr
61
+
62
+ @staticmethod
63
+ def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
35
64
  return arr
36
65
 
37
- def topleft_pad(self, arr, shape: Tuple[int], padval: int = 0):
38
- b = self.preallocate_array(shape, arr.dtype)
39
- self.add(b, padval, out=b)
66
+ def topleft_pad(
67
+ self, arr: BackendArray, shape: Tuple[int], padval: int = 0
68
+ ) -> BackendArray:
69
+ b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
40
70
  aind = [slice(None, None)] * arr.ndim
41
71
  bind = [slice(None, None)] * arr.ndim
42
72
  for i in range(arr.ndim):
@@ -47,43 +77,29 @@ class JaxBackend(NumpyFFTWBackend):
47
77
  b = b.at[tuple(bind)].set(arr[tuple(aind)])
48
78
  return b
49
79
 
50
-
51
- def add(self, x1, x2, out = None, *args, **kwargs):
52
- x1 = self.to_backend_array(x1)
53
- x2 = self.to_backend_array(x2)
54
- ret = self._array_backend.add(x1, x2, *args, **kwargs)
55
-
56
- if out is not None:
57
- out = out.at[:].set(ret)
58
- return ret
59
-
60
- def subtract(self, x1, x2, out = None, *args, **kwargs):
61
- x1 = self.to_backend_array(x1)
62
- x2 = self.to_backend_array(x2)
63
- ret = self._array_backend.subtract(x1, x2, *args, **kwargs)
64
- if out is not None:
65
- out = out.at[:].set(ret)
66
- return ret
67
-
68
- def multiply(self, x1, x2, out = None, *args, **kwargs):
69
- x1 = self.to_backend_array(x1)
70
- x2 = self.to_backend_array(x2)
71
- ret = self._array_backend.multiply(x1, x2, *args, **kwargs)
72
- if out is not None:
73
- out = out.at[:].set(ret)
74
- return ret
75
-
76
- def divide(self, x1, x2, out = None, *args, **kwargs):
77
- x1 = self.to_backend_array(x1)
78
- x2 = self.to_backend_array(x2)
79
- ret = self._array_backend.divide(x1, x2, *args, **kwargs)
80
- if out is not None:
81
- out = out.at[:].set(ret)
82
- return ret
83
-
84
- def fill(self, arr, value: float) -> None:
85
- arr.at[:].set(value)
86
-
80
+ def _create_ufuncs(self):
81
+ ufuncs = [
82
+ "add",
83
+ "subtract",
84
+ "multiply",
85
+ "divide",
86
+ "square",
87
+ "sqrt",
88
+ "maximum",
89
+ ]
90
+ for ufunc in ufuncs:
91
+ backend_method = emulate_out(getattr(self._array_backend, ufunc))
92
+ setattr(self, ufunc, staticmethod(backend_method))
93
+
94
+ ufuncs = ["zeros", "full"]
95
+ for ufunc in ufuncs:
96
+ backend_method = getattr(self._array_backend, ufunc)
97
+ setattr(self, ufunc, staticmethod(backend_method))
98
+
99
+ def fill(self, arr: BackendArray, value: float) -> BackendArray:
100
+ return self._array_backend.full(
101
+ shape=arr.shape, dtype=arr.dtype, fill_value=value
102
+ )
87
103
 
88
104
  def build_fft(
89
105
  self,
@@ -92,127 +108,175 @@ class JaxBackend(NumpyFFTWBackend):
92
108
  inverse_fast_shape: Tuple[int] = None,
93
109
  **kwargs,
94
110
  ) -> Tuple[Callable, Callable]:
95
- """
96
- Build fft builder functions.
97
-
98
- Parameters
99
- ----------
100
- fast_shape : tuple
101
- Tuple of integers corresponding to fast convolution shape
102
- (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
103
- fast_ft_shape : tuple
104
- Tuple of integers corresponding to the shape of the Fourier
105
- transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
106
- inverse_fast_shape : tuple, optional
107
- Output shape of the inverse Fourier transform. By default fast_shape.
108
- **kwargs : dict, optional
109
- Unused keyword arguments.
110
-
111
- Returns
112
- -------
113
- tuple
114
- Tupple containing callable rfft and irfft object.
115
- """
116
111
  if inverse_fast_shape is None:
117
112
  inverse_fast_shape = fast_shape
118
113
 
119
- def rfftn(
120
- arr, out, shape: Tuple[int] = fast_shape
121
- ) -> None:
122
- out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
114
+ def rfftn(arr, out, shape=fast_shape):
115
+ return self._array_backend.fft.rfftn(arr, s=shape)
123
116
 
124
- def irfftn(
125
- arr, out, shape: Tuple[int] = inverse_fast_shape
126
- ) -> None:
127
- out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
117
+ def irfftn(arr, out, shape=fast_shape):
118
+ return self._array_backend.fft.irfftn(arr, s=shape)
128
119
 
129
120
  return rfftn, irfftn
130
121
 
131
- def sharedarr_to_arr(self, shm, shape: Tuple[int], dtype: str):
132
- return shm
122
+ def compute_convolution_shapes(
123
+ self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
124
+ ) -> Tuple[List[int], List[int], List[int]]:
125
+ conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
126
+ arr1_shape, arr2_shape
127
+ )
133
128
 
134
- @staticmethod
135
- def arr_to_sharedarr(arr, shared_memory_handler: type = None):
136
- return arr
129
+ is_odd = fast_shape[-1] % 2
130
+ fast_shape[-1] += is_odd
131
+ fast_ft_shape[-1] += is_odd
137
132
 
138
- def rotate_array(
133
+ return conv_shape, fast_shape, fast_ft_shape
134
+
135
+ def rigid_transform(
139
136
  self,
140
- arr,
141
- rotation_matrix,
142
- arr_mask = None,
143
- translation = None,
144
- use_geometric_center: bool = False,
145
- out = None,
146
- out_mask = None,
147
- order: int = 3,
148
- ) -> None:
137
+ arr: BackendArray,
138
+ rotation_matrix: BackendArray,
139
+ out: BackendArray = None,
140
+ out_mask: BackendArray = None,
141
+ translation: BackendArray = None,
142
+ arr_mask: BackendArray = None,
143
+ order: int = 1,
144
+ **kwargs,
145
+ ) -> Tuple[BackendArray, BackendArray]:
149
146
  rotate_mask = arr_mask is not None
150
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
151
- translation = self.zeros(arr.ndim) if translation is None else translation
152
-
153
- indices = self._array_backend.indices(arr.shape).reshape(
154
- (len(arr.shape), -1)
155
- ).astype(self._float_dtype)
147
+ center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
156
148
 
157
- center = self.divide(arr.shape, 2)
158
- if not use_geometric_center:
159
- center = self.center_of_mass(arr, cutoff=0)
160
- center = center[:, None]
149
+ indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
150
+ indices = indices.reshape((arr.ndim, -1))
161
151
  indices = indices.at[:].add(-center)
162
- rotation_matrix = self._array_backend.linalg.inv(rotation_matrix)
163
- indices = self._array_backend.matmul(rotation_matrix, indices)
152
+ indices = self._array_backend.matmul(rotation_matrix.T, indices)
164
153
  indices = indices.at[:].add(center)
154
+ if translation is not None:
155
+ indices = indices.at[:].add(translation)
165
156
 
166
- out = self.zeros_like(arr) if out is None else out
167
- out_slice = tuple(slice(0, stop) for stop in arr.shape)
168
-
169
- out = out.at[out_slice].set(
170
- self.scipy.ndimage.map_coordinates(
171
- arr, indices, order=order
172
- ).reshape(arr.shape)
157
+ out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
158
+ arr.shape
173
159
  )
174
160
 
161
+ out_mask = arr_mask
175
162
  if rotate_mask:
176
- out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
177
- out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
178
- out_mask = out_mask.at[out_mask_slice].set(
179
- self.scipy.ndimage.map_coordinates(
180
- arr_mask, indices, order=order
181
- ).reshape(arr.shape)
182
- )
163
+ out_mask = self.scipy.ndimage.map_coordinates(
164
+ arr_mask, indices, order=order
165
+ ).reshape(arr_mask.shape)
183
166
 
184
- match return_type:
185
- case 0:
186
- return None
187
- case 1:
188
- return out
189
- case 2:
190
- return out_mask
191
- case 3:
192
- return out, out_mask
167
+ return out, out_mask
193
168
 
194
169
  def max_score_over_rotations(
195
170
  self,
196
- score_space,
197
- internal_scores,
198
- internal_rotations,
171
+ scores: BackendArray,
172
+ max_scores: BackendArray,
173
+ rotations: BackendArray,
199
174
  rotation_index: int,
200
- ):
175
+ ) -> Tuple[BackendArray, BackendArray]:
176
+ update = self.greater(max_scores, scores)
177
+ max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
178
+ rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
179
+ return max_scores, rotations
180
+
181
+ def _scan(
182
+ self,
183
+ matching_data: type,
184
+ splits: Tuple[Tuple[slice, slice]],
185
+ n_jobs: int,
186
+ callback_class,
187
+ rotate_mask: bool = False,
188
+ **kwargs,
189
+ ) -> List:
201
190
  """
202
- Modify internal_scores and internal_rotations inplace with scores and rotation
203
- index respectively, wherever score_sapce is larger than internal scores.
204
-
205
- Parameters
206
- ----------
207
- score_space : CupyArray
208
- The score space to compare against internal_scores.
209
- internal_scores : CupyArray
210
- The internal scores to update with maximum scores.
211
- internal_rotations : CupyArray
212
- The internal rotations corresponding to the maximum scores.
213
- rotation_index : int
214
- The index representing the current rotation.
191
+ Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
192
+ :py:class:`tme.analyzer.MaxScoreOverRotations`.
215
193
  """
216
- indices = score_space > internal_scores
217
- internal_scores.at[indices].set(score_space[indices])
218
- internal_rotations.at[indices].set(rotation_index)
194
+ from ._jax_utils import scan as scan_inner
195
+
196
+ pad_target = True if len(splits) > 1 else False
197
+ convolution_mode = "valid" if pad_target else "same"
198
+ target_pad = matching_data.target_padding(pad_target=pad_target)
199
+
200
+ target_shape = tuple(
201
+ (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
202
+ )
203
+ fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
204
+ target_shape=self.to_numpy_array(target_shape),
205
+ template_shape=self.to_numpy_array(matching_data._template.shape),
206
+ pad_fourier=False,
207
+ )
208
+
209
+ analyzer_args = {
210
+ "convolution_mode": convolution_mode,
211
+ "fourier_shift": shift,
212
+ "targetshape": target_shape,
213
+ "templateshape": matching_data._template.shape,
214
+ }
215
+
216
+ create_target_filter = matching_data.target_filter is not None
217
+ create_template_filter = matching_data.template_filter is not None
218
+ create_filter = create_target_filter or create_template_filter
219
+
220
+ ret, template_filter, target_filter = [], 1, 1
221
+ rotation_mapping = {
222
+ self.tobytes(matching_data.rotations[i]): i
223
+ for i in range(matching_data.rotations.shape[0])
224
+ }
225
+ for split_start in range(0, len(splits), n_jobs):
226
+ split_subset = splits[split_start : (split_start + n_jobs)]
227
+ if not len(split_subset):
228
+ continue
229
+
230
+ targets, translation_offsets = [], []
231
+ for target_split, template_split in split_subset:
232
+ base = matching_data.subset_by_slice(
233
+ target_slice=target_split,
234
+ target_pad=target_pad,
235
+ template_slice=template_split,
236
+ )
237
+ translation_offsets.append(base._translation_offset)
238
+ targets.append(self.topleft_pad(base._target, fast_shape))
239
+
240
+ if create_filter:
241
+ filter_args = {
242
+ "data_rfft": self.fft.rfftn(targets[0]),
243
+ "return_real_fourier": True,
244
+ "shape_is_real_fourier": False,
245
+ }
246
+
247
+ if create_template_filter:
248
+ template_filter = matching_data.template_filter(
249
+ shape=matching_data._template.shape, **filter_args
250
+ )["data"]
251
+ template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
252
+
253
+ if create_target_filter:
254
+ target_filter = matching_data.template_filter(
255
+ shape=fast_shape, **filter_args
256
+ )["data"]
257
+ target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
258
+
259
+ create_filter, create_template_filter, create_target_filter = (False,) * 3
260
+ base, targets = None, self._array_backend.stack(targets)
261
+ scores, rotations = scan_inner(
262
+ targets,
263
+ matching_data.template,
264
+ matching_data.template_mask,
265
+ matching_data.rotations,
266
+ template_filter,
267
+ target_filter,
268
+ fast_shape,
269
+ rotate_mask,
270
+ )
271
+
272
+ for index in range(scores.shape[0]):
273
+ temp = callback_class(
274
+ scores=scores[index],
275
+ rotations=rotations[index],
276
+ thread_safe=False,
277
+ offset=translation_offsets[index],
278
+ )
279
+ temp.rotation_mapping = rotation_mapping
280
+ ret.append(tuple(temp._postprocess(**analyzer_args)))
281
+
282
+ return ret