pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post2.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ from typing import Tuple, List, Dict, Any
11
11
 
12
12
  import numpy as np
13
13
 
14
- from ..types import BackendArray
14
+ from ..types import JaxArray
15
15
  from .npfftw_backend import NumpyFFTWBackend, shm_type
16
16
 
17
17
 
@@ -54,25 +54,23 @@ class JaxBackend(NumpyFFTWBackend):
54
54
  self.scipy = jsp
55
55
  self._create_ufuncs()
56
56
 
57
- def from_sharedarr(self, arr: BackendArray) -> BackendArray:
57
+ def from_sharedarr(self, arr: JaxArray) -> JaxArray:
58
58
  return arr
59
59
 
60
60
  @staticmethod
61
- def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
61
+ def to_sharedarr(arr: JaxArray, shared_memory_handler: type = None) -> shm_type:
62
62
  return arr
63
63
 
64
64
  @staticmethod
65
- def at(arr, idx, value) -> BackendArray:
66
- arr = arr.at[idx].set(value)
67
- return arr
65
+ def at(arr, idx, value) -> JaxArray:
66
+ return arr.at[idx].set(value)
68
67
 
69
68
  def addat(self, arr, indices, values):
70
- arr = arr.at[indices].add(values)
71
- return arr
69
+ return arr.at[indices].add(values)
72
70
 
73
71
  def topleft_pad(
74
- self, arr: BackendArray, shape: Tuple[int], padval: int = 0
75
- ) -> BackendArray:
72
+ self, arr: JaxArray, shape: Tuple[int], padval: int = 0
73
+ ) -> JaxArray:
76
74
  b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
77
75
  aind = [slice(None, None)] * arr.ndim
78
76
  bind = [slice(None, None)] * arr.ndim
@@ -95,6 +93,7 @@ class JaxBackend(NumpyFFTWBackend):
95
93
  "maximum",
96
94
  "exp",
97
95
  "mod",
96
+ "dot",
98
97
  ]
99
98
  for ufunc in ufuncs:
100
99
  backend_method = emulate_out(getattr(self._array_backend, ufunc))
@@ -105,65 +104,68 @@ class JaxBackend(NumpyFFTWBackend):
105
104
  backend_method = getattr(self._array_backend, ufunc)
106
105
  setattr(self, ufunc, staticmethod(backend_method))
107
106
 
108
- def fill(self, arr: BackendArray, value: float) -> BackendArray:
107
+ def fill(self, arr: JaxArray, value: float) -> JaxArray:
109
108
  return self._array_backend.full(
110
109
  shape=arr.shape, dtype=arr.dtype, fill_value=value
111
110
  )
112
111
 
113
- def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
112
+ def rfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
114
113
  return self._array_backend.fft.rfftn(arr, **kwargs)
115
114
 
116
- def irfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
115
+ def irfftn(self, arr: JaxArray, *args, **kwargs) -> JaxArray:
117
116
  return self._array_backend.fft.irfftn(arr, **kwargs)
118
117
 
119
- def rigid_transform(
118
+ def _interpolate(self, arr, indices, order: int = 1):
119
+ ret = self.scipy.ndimage.map_coordinates(arr, indices, order=order)
120
+ return ret.reshape(arr.shape)
121
+
122
+ def _index_grid(self, shape: Tuple[int]) -> JaxArray:
123
+ """
124
+ Create homogeneous coordinate grid.
125
+
126
+ Parameters
127
+ ----------
128
+ shape : tuple of int
129
+ Shape to create the grid for
130
+
131
+ Returns
132
+ -------
133
+ JaxArray
134
+ Coordinate grid of shape (ndim + int(homogeneous), n_points)
135
+ """
136
+ indices = self._array_backend.indices(shape, dtype=self._float_dtype)
137
+ indices = indices.reshape((len(shape), -1))
138
+ ones = self._array_backend.ones((1, indices.shape[1]), dtype=indices.dtype)
139
+ return self._array_backend.concatenate([indices, ones], axis=0)
140
+
141
+ def _transform_indices(self, indices: JaxArray, matrix: JaxArray) -> JaxArray:
142
+ return self._array_backend.matmul(matrix[:-1], indices)
143
+
144
+ def _rigid_transform(
120
145
  self,
121
- arr: BackendArray,
122
- rotation_matrix: BackendArray,
123
- out: BackendArray = None,
124
- out_mask: BackendArray = None,
125
- translation: BackendArray = None,
126
- arr_mask: BackendArray = None,
146
+ arr: JaxArray,
147
+ matrix: JaxArray,
148
+ out: JaxArray = None,
149
+ out_mask: JaxArray = None,
150
+ arr_mask: JaxArray = None,
127
151
  order: int = 1,
128
152
  **kwargs,
129
- ) -> Tuple[BackendArray, BackendArray]:
130
- rotate_mask = arr_mask is not None
131
-
132
- # This approach is only valid for order <= 1
133
- if arr.ndim != rotation_matrix.shape[0]:
134
- matrix = self._array_backend.zeros((arr.ndim, arr.ndim))
135
- matrix = matrix.at[0, 0].set(1)
136
- matrix = matrix.at[1:, 1:].add(rotation_matrix)
137
- rotation_matrix = matrix
138
-
139
- center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
140
- indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
141
- indices = indices.reshape((arr.ndim, -1))
142
- indices = indices.at[:].add(-center)
143
- indices = self._array_backend.matmul(rotation_matrix.T, indices)
144
- indices = indices.at[:].add(center)
145
- if translation is not None:
146
- indices = indices.at[:].add(translation)
147
-
148
- out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
149
- arr.shape
150
- )
153
+ ) -> Tuple[JaxArray, JaxArray]:
154
+ indices = self._index_grid(arr.shape)
155
+ indices = self._transform_indices(indices, matrix)
151
156
 
152
- out_mask = arr_mask
153
- if rotate_mask:
154
- out_mask = self.scipy.ndimage.map_coordinates(
155
- arr_mask, indices, order=order
156
- ).reshape(arr_mask.shape)
157
-
158
- return out, out_mask
157
+ arr = self._interpolate(arr, indices, order)
158
+ if arr_mask is not None:
159
+ arr_mask = self._interpolate(out_mask, indices, order)
160
+ return arr, arr_mask
159
161
 
160
162
  def max_score_over_rotations(
161
163
  self,
162
- scores: BackendArray,
163
- max_scores: BackendArray,
164
- rotations: BackendArray,
164
+ scores: JaxArray,
165
+ max_scores: JaxArray,
166
+ rotations: JaxArray,
165
167
  rotation_index: int,
166
- ) -> Tuple[BackendArray, BackendArray]:
168
+ ) -> Tuple[JaxArray, JaxArray]:
167
169
  update = self.greater(max_scores, scores)
168
170
  max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
169
171
  rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
@@ -212,64 +214,69 @@ class JaxBackend(NumpyFFTWBackend):
212
214
  callback_class: object,
213
215
  callback_class_args: Dict,
214
216
  rotate_mask: bool = False,
217
+ background_correction: str = None,
218
+ match_projection: bool = False,
215
219
  **kwargs,
216
220
  ) -> List:
217
221
  """
218
- Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
219
- :py:class:`tme.analyzer.MaxScoreOverRotations`.
222
+ Emulates output of :py:meth:`tme.matching_exhaustive._match_exhaustive`.
220
223
  """
221
224
  from ._jax_utils import setup_scan
225
+ from ..matching_utils import setup_filter
222
226
  from ..analyzer import MaxScoreOverRotations
223
227
 
224
228
  pad_target = True if len(splits) > 1 else False
225
- convolution_mode = "valid" if pad_target else "same"
226
229
  target_pad = matching_data.target_padding(pad_target=pad_target)
230
+ template_shape = matching_data._batch_shape(
231
+ matching_data.template.shape, matching_data._target_batch
232
+ )
227
233
 
228
- score_mask = self.full((1,), fill_value=1, dtype=bool)
234
+ score_mask = 1
229
235
  target_shape = tuple(
230
236
  (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
231
237
  )
232
- conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
233
- target_shape=self.to_numpy_array(target_shape),
234
- template_shape=self.to_numpy_array(matching_data._template.shape),
235
- batch_mask=self.to_numpy_array(matching_data._batch_mask),
238
+ conv_shape, fast_shape, fast_ft_shape, shift = matching_data.fourier_padding(
239
+ target_shape=target_shape
236
240
  )
241
+
237
242
  analyzer_args = {
238
243
  "shape": fast_shape,
239
244
  "fourier_shift": shift,
240
245
  "fast_shape": fast_shape,
241
- "targetshape": target_shape,
242
- "templateshape": matching_data.template.shape,
246
+ "templateshape": template_shape,
243
247
  "convolution_shape": conv_shape,
244
- "convolution_mode": convolution_mode,
248
+ "convolution_mode": "valid" if pad_target else "same",
245
249
  "thread_safe": False,
246
250
  "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
247
251
  "n_rotations": matching_data.rotations.shape[0],
248
252
  "jax_mode": True,
249
253
  }
254
+ analyzer_args.update(callback_class_args)
255
+
250
256
  create_target_filter = matching_data.target_filter is not None
251
257
  create_template_filter = matching_data.template_filter is not None
252
258
  create_filter = create_target_filter or create_template_filter
253
259
 
254
- # Applying the filter leads to more FFTs
255
- fastt_shape = matching_data._template.shape
256
- if create_template_filter:
257
- fastt_shape = matching_data._template.shape
260
+ bg_tmpl = 1
261
+ if background_correction == "phase-scrambling":
262
+ bg_tmpl = matching_data.transform_template(
263
+ "phase_randomization", reverse=True
264
+ )
265
+ bg_tmpl = self.astype(bg_tmpl, self._float_dtype)
258
266
 
267
+ rotations = self.astype(matching_data.rotations, self._float_dtype)
259
268
  ret, template_filter, target_filter = [], 1, 1
260
269
  rotation_mapping = {
261
- self.tobytes(matching_data.rotations[i]): i
262
- for i in range(matching_data.rotations.shape[0])
270
+ self.tobytes(rotations[i]): i for i in range(rotations.shape[0])
263
271
  }
264
272
  for split_start in range(0, len(splits), n_jobs):
265
273
 
266
274
  analyzer_kwargs = []
267
-
268
275
  split_subset = splits[split_start : (split_start + n_jobs)]
269
276
  if not len(split_subset):
270
277
  continue
271
278
 
272
- targets, translation_offsets = [], []
279
+ targets = []
273
280
  for target_split, template_split in split_subset:
274
281
  base, translation_offset = matching_data.subset_by_slice(
275
282
  target_slice=target_split,
@@ -278,52 +285,50 @@ class JaxBackend(NumpyFFTWBackend):
278
285
  )
279
286
  cur_args = analyzer_args.copy()
280
287
  cur_args["offset"] = translation_offset
281
- cur_args.update(callback_class_args)
288
+ cur_args["targetshape"] = base._output_shape
282
289
  analyzer_kwargs.append(cur_args)
283
290
 
284
291
  if pad_target:
285
292
  score_mask = base._score_mask(fast_shape, shift)
286
293
 
287
- _target = self.astype(base._target, self._float_dtype)
288
- translation_offsets.append(translation_offset)
289
- targets.append(self.topleft_pad(_target, fast_shape))
294
+ # We prepad outside of jit to guarantee the stack operation works
295
+ targets.append(self.topleft_pad(base._target, fast_shape))
290
296
 
291
297
  if create_filter:
292
- filter_args = {
293
- "data_rfft": self.fft.rfftn(targets[0]),
294
- "return_real_fourier": True,
295
- }
296
-
297
- if create_template_filter:
298
- template_filter = matching_data.template_filter(
299
- shape=fastt_shape, **filter_args
300
- )["data"]
301
- template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
302
-
303
- if create_target_filter:
304
- target_filter = matching_data.target_filter(
305
- shape=fast_shape, **filter_args
306
- )["data"]
307
- target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
308
-
309
- create_filter, create_template_filter, create_target_filter = (False,) * 3
310
- base, targets = None, self._array_backend.stack(targets)
298
+ # This is technically inaccurate for whitening filters
299
+ template_filter, target_filter = setup_filter(
300
+ matching_data=base,
301
+ fast_shape=fast_shape,
302
+ fast_ft_shape=fast_ft_shape,
303
+ pad_template_filter=False,
304
+ apply_target_filter=False,
305
+ )
311
306
 
307
+ # For projection matching we allow broadcasting the first dimension
308
+ # This becomes problematic when applying per-tilt filters to the target
309
+ # as the number of tilts does not necessarily coincide with the ideal
310
+ # fourier shape. Hence we pad the target_filter with zeros here
311
+ if target_filter.shape != (1,):
312
+ target_filter = self.topleft_pad(target_filter, fast_ft_shape)
313
+
314
+ base, targets = None, self._array_backend.stack(targets)
312
315
  scan_inner = setup_scan(
313
316
  analyzer_kwargs=analyzer_kwargs,
314
- callback_class=callback_class,
317
+ analyzer=callback_class,
315
318
  fast_shape=fast_shape,
316
- rotate_mask=rotate_mask
319
+ rotate_mask=rotate_mask,
320
+ match_projection=match_projection,
317
321
  )
318
322
 
319
323
  states = scan_inner(
320
324
  self.astype(targets, self._float_dtype),
321
325
  self.astype(matching_data.template, self._float_dtype),
322
326
  self.astype(matching_data.template_mask, self._float_dtype),
323
- matching_data.rotations,
327
+ rotations,
324
328
  template_filter,
325
329
  target_filter,
326
330
  score_mask,
331
+ bg_tmpl,
327
332
  )
328
333
 
329
334
  ndim = targets.ndim - 1
@@ -1105,23 +1105,6 @@ class MatchingBackend(ABC):
1105
1105
  def irfftn(self, **kwargs):
1106
1106
  """Perform an n-D real inverse FFT."""
1107
1107
 
1108
- def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
1109
- """
1110
- Extract the centered portion of an array based on a new shape.
1111
-
1112
- Parameters
1113
- ----------
1114
- arr : BackendArray
1115
- Input data.
1116
- newshape : tuple
1117
- Desired shape for the central portion.
1118
-
1119
- Returns
1120
- -------
1121
- BackendArray
1122
- Central portion of the array with shape ``newshape``.
1123
- """
1124
-
1125
1108
  @abstractmethod
1126
1109
  def compute_convolution_shapes(
1127
1110
  self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
@@ -115,35 +115,6 @@ class MLXBackend(NumpyFFTWBackend):
115
115
  )
116
116
  return self.to_backend_array(ret)
117
117
 
118
- def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
119
- """
120
- Extract the centered portion of an array based on a new shape.
121
-
122
- Parameters
123
- ----------
124
- arr : NDArray
125
- Input array.
126
- newshape : tuple
127
- Desired shape for the central portion.
128
-
129
- Returns
130
- -------
131
- NDArray
132
- Central portion of the array with shape `newshape`.
133
-
134
- References
135
- ----------
136
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
137
- """
138
- new_shape = self.to_backend_array(newshape)
139
- current_shape = self.to_backend_array(arr.shape)
140
- starts = self.subtract(current_shape, new_shape)
141
- starts = self.astype(self.divide(starts, 2), self._int_dtype)
142
- stops = self.astype(self.add(starts, newshape), self._int_dtype)
143
- starts, stops = starts.tolist(), stops.tolist()
144
- box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
145
- return arr[box]
146
-
147
118
  def rfftn(self, arr, *args, **kwargs):
148
119
  return self.fft.rfftn(arr, stream=self._array_backend.cpu, **kwargs)
149
120