pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- """ Backend Apple's MLX library for template matching.
1
+ """ Backend using Apple's MLX library for template matching.
2
2
 
3
3
  Copyright (c) 2024 European Molecular Biology Laboratory
4
4
 
@@ -9,34 +9,38 @@ from typing import Tuple, List, Callable
9
9
  import numpy as np
10
10
 
11
11
  from .npfftw_backend import NumpyFFTWBackend
12
- from ..types import NDArray, MlxArray, Scalar
12
+ from ..types import NDArray, MlxArray, Scalar, shm_type
13
13
 
14
14
 
15
15
  class MLXBackend(NumpyFFTWBackend):
16
16
  """
17
- A MLX based backend for template matching.
17
+ A mlx-based matching backend.
18
18
  """
19
19
 
20
20
  def __init__(
21
21
  self,
22
22
  device="cpu",
23
- default_dtype=None,
23
+ float_dtype=None,
24
24
  complex_dtype=None,
25
- default_dtype_int=None,
25
+ int_dtype=None,
26
+ overflow_safe_dtype=None,
26
27
  **kwargs,
27
28
  ):
28
29
  import mlx.core as mx
29
30
 
30
31
  device = mx.cpu if device == "cpu" else mx.gpu
31
- default_dtype = mx.float32 if default_dtype is None else default_dtype
32
+ float_dtype = mx.float32 if float_dtype is None else float_dtype
32
33
  complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
33
- default_dtype_int = mx.int32 if default_dtype_int is None else default_dtype_int
34
+ int_dtype = mx.int32 if int_dtype is None else int_dtype
35
+ if overflow_safe_dtype is None:
36
+ overflow_safe_dtype = mx.float32
34
37
 
35
38
  super().__init__(
36
39
  array_backend=mx,
37
- default_dtype=default_dtype,
40
+ float_dtype=float_dtype,
38
41
  complex_dtype=complex_dtype,
39
- default_dtype_int=default_dtype_int,
42
+ int_dtype=int_dtype,
43
+ overflow_safe_dtype=overflow_safe_dtype,
40
44
  )
41
45
 
42
46
  self.device = device
@@ -68,6 +72,15 @@ class MLXBackend(NumpyFFTWBackend):
68
72
  return None
69
73
  return self._array_backend.add(x1, x2, **kwargs)
70
74
 
75
+ def multiply(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
76
+ x1 = self.to_backend_array(x1)
77
+ x2 = self.to_backend_array(x2)
78
+
79
+ if out is not None:
80
+ out[:] = self._array_backend.multiply(x1, x2, **kwargs)
81
+ return None
82
+ return self._array_backend.multiply(x1, x2, **kwargs)
83
+
71
84
  def std(self, arr: MlxArray, axis) -> Scalar:
72
85
  return self._array_backend.sqrt(arr.var(axis=axis))
73
86
 
@@ -80,30 +93,12 @@ class MLXBackend(NumpyFFTWBackend):
80
93
  def tobytes(self, arr):
81
94
  return self.to_numpy_array(arr).tobytes()
82
95
 
83
- def preallocate_array(self, shape: Tuple[int], dtype: type = None) -> NDArray:
84
- """
85
- Returns a byte-aligned array of zeros with specified shape and dtype.
86
-
87
- Parameters
88
- ----------
89
- shape : Tuple[int]
90
- Desired shape for the array.
91
- dtype : type, optional
92
- Desired data type for the array.
93
-
94
- Returns
95
- -------
96
- NDArray
97
- Byte-aligned array of zeros with specified shape and dtype.
98
- """
99
- arr = self._array_backend.zeros(shape, dtype=dtype)
100
- return arr
101
-
102
96
  def full(self, shape, fill_value, dtype=None):
103
97
  return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)
104
98
 
105
- def fill(self, arr: MlxArray, value: Scalar) -> None:
99
+ def fill(self, arr: MlxArray, value: Scalar) -> MlxArray:
106
100
  arr[:] = value
101
+ return arr
107
102
 
108
103
  def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
109
104
  return self._array_backend.zeros(shape=shape, dtype=dtype)
@@ -141,8 +136,8 @@ class MLXBackend(NumpyFFTWBackend):
141
136
  new_shape = self.to_backend_array(newshape)
142
137
  current_shape = self.to_backend_array(arr.shape)
143
138
  starts = self.subtract(current_shape, new_shape)
144
- starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
145
- stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
139
+ starts = self.astype(self.divide(starts, 2), self._int_dtype)
140
+ stops = self.astype(self.add(starts, newshape), self._int_dtype)
146
141
  starts, stops = starts.tolist(), stops.tolist()
147
142
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
148
143
  return arr[box]
@@ -185,13 +180,11 @@ class MLXBackend(NumpyFFTWBackend):
185
180
 
186
181
  return rfftn, irfftn
187
182
 
188
- def sharedarr_to_arr(
189
- self, shape: Tuple[int], dtype: str, shm: MlxArray
190
- ) -> MlxArray:
191
- return shm
183
+ def from_sharedarr(self, arr: MlxArray) -> MlxArray:
184
+ return arr
192
185
 
193
186
  @staticmethod
194
- def arr_to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> MlxArray:
187
+ def to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> shm_type:
195
188
  return arr
196
189
 
197
190
  def topk_indices(self, arr: NDArray, k: int):
@@ -200,7 +193,7 @@ class MLXBackend(NumpyFFTWBackend):
200
193
  ret = [self.to_backend_array(x) for x in ret]
201
194
  return ret
202
195
 
203
- def rotate_array(
196
+ def rigid_transform(
204
197
  self,
205
198
  arr: NDArray,
206
199
  rotation_matrix: NDArray,
@@ -210,10 +203,8 @@ class MLXBackend(NumpyFFTWBackend):
210
203
  out: NDArray = None,
211
204
  out_mask: NDArray = None,
212
205
  order: int = 3,
206
+ **kwargs,
213
207
  ) -> None:
214
- rotate_mask = arr_mask is not None
215
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
216
-
217
208
  arr = self.to_numpy_array(arr)
218
209
  rotation_matrix = self.to_numpy_array(rotation_matrix)
219
210
 
@@ -223,46 +214,32 @@ class MLXBackend(NumpyFFTWBackend):
223
214
  if translation is not None:
224
215
  translation = self.to_numpy_array(translation)
225
216
 
226
- out_pass, out_mask_pass = None, None
227
- if out is not None:
228
- out_pass = self.to_numpy_array(out)
229
- if out_mask is not None:
230
- out_mask_pass = self.to_numpy_array(out_mask)
217
+ if out is None:
218
+ out = self.zeros(arr.shape)
219
+ if out_mask is None and arr_mask is not None:
220
+ out_mask_pass = self.zeros(arr_mask.shape)
231
221
 
232
- ret = NumpyFFTWBackend().rotate_array(
222
+ ret = NumpyFFTWBackend().rigid_transform(
233
223
  arr=arr,
234
224
  rotation_matrix=rotation_matrix,
235
225
  arr_mask=arr_mask,
236
226
  translation=translation,
237
227
  use_geometric_center=use_geometric_center,
238
- out=out_pass,
239
- out_mask=out_mask_pass,
240
228
  order=order,
241
229
  )
242
230
 
243
- if ret is not None:
244
- if len(ret) == 1 and out is None:
245
- out_pass = ret
246
- elif len(ret) == 1 and out_mask is None:
247
- out_mask_pass = ret
248
- else:
249
- out_pass, out_mask_pass = ret
231
+ out_pass, out_mask_pass = ret
232
+ out[:] = self.to_backend_array(out_pass)
250
233
 
251
- if out is not None:
252
- out[:] = self.to_backend_array(out_pass)
234
+ if out_mask_pass is not None:
235
+ out_mask_pass = self.to_backend_array(out_mask_pass)
253
236
 
254
237
  if out_mask is not None:
255
- out_mask[:] = self.to_backend_array(out_mask_pass)
256
-
257
- match return_type:
258
- case 0:
259
- return None
260
- case 1:
261
- return out
262
- case 2:
263
- return out_mask
264
- case 3:
265
- return out, out_mask
238
+ out_mask[:] = out_mask_pass
239
+ else:
240
+ out_mask = out_mask_pass
241
+
242
+ return out, out_mask
266
243
 
267
244
  def indices(self, arr: List) -> MlxArray:
268
245
  ret = NumpyFFTWBackend().indices(arr)