pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,38 @@
1
1
  """ Backend using jax for template matching.
2
2
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
- from typing import Tuple, Callable
7
+ from functools import wraps
8
+ from typing import Tuple, List, Callable
9
+
10
+ from ..types import BackendArray
11
+ from .npfftw_backend import NumpyFFTWBackend, shm_type
12
+
13
+
14
+ def emulate_out(func):
15
+ """
16
+ Adds an out argument to write output of ``func`` to.
17
+ """
18
+
19
+ @wraps(func)
20
+ def inner(*args, out=None, **kwargs):
21
+ ret = func(*args, **kwargs)
22
+ if out is not None:
23
+ out = out.at[:].set(ret)
24
+ return out
25
+ return ret
26
+
27
+ return inner
8
28
 
9
- from .npfftw_backend import NumpyFFTWBackend
10
29
 
11
30
  class JaxBackend(NumpyFFTWBackend):
12
- def __init__(
13
- self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs
14
- ):
31
+ """
32
+ A jax-based matching backend.
33
+ """
34
+
35
+ def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
15
36
  import jax.scipy as jsp
16
37
  import jax.numpy as jnp
17
38
 
@@ -19,24 +40,33 @@ class JaxBackend(NumpyFFTWBackend):
19
40
  complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
20
41
  int_dtype = jnp.int32 if int_dtype is None else int_dtype
21
42
 
22
- self.scipy = jsp
23
43
  super().__init__(
24
44
  array_backend=jnp,
25
45
  float_dtype=float_dtype,
26
46
  complex_dtype=complex_dtype,
27
47
  int_dtype=int_dtype,
48
+ overflow_safe_dtype=float_dtype,
28
49
  )
50
+ self.scipy = jsp
51
+ self._create_ufuncs()
52
+ try:
53
+ from ._jax_utils import scan as _
54
+
55
+ self.scan = self._scan
56
+ except Exception:
57
+ pass
29
58
 
30
- def to_backend_array(self, arr):
31
- return self._array_backend.asarray(arr)
59
+ def from_sharedarr(self, arr: BackendArray) -> BackendArray:
60
+ return arr
32
61
 
33
- def preallocate_array(self, shape: Tuple[int], dtype: type):
34
- arr = self._array_backend.zeros(shape, dtype=dtype)
62
+ @staticmethod
63
+ def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
35
64
  return arr
36
65
 
37
- def topleft_pad(self, arr, shape: Tuple[int], padval: int = 0):
38
- b = self.preallocate_array(shape, arr.dtype)
39
- self.add(b, padval, out=b)
66
+ def topleft_pad(
67
+ self, arr: BackendArray, shape: Tuple[int], padval: int = 0
68
+ ) -> BackendArray:
69
+ b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
40
70
  aind = [slice(None, None)] * arr.ndim
41
71
  bind = [slice(None, None)] * arr.ndim
42
72
  for i in range(arr.ndim):
@@ -47,43 +77,29 @@ class JaxBackend(NumpyFFTWBackend):
47
77
  b = b.at[tuple(bind)].set(arr[tuple(aind)])
48
78
  return b
49
79
 
50
-
51
- def add(self, x1, x2, out = None, *args, **kwargs):
52
- x1 = self.to_backend_array(x1)
53
- x2 = self.to_backend_array(x2)
54
- ret = self._array_backend.add(x1, x2, *args, **kwargs)
55
-
56
- if out is not None:
57
- out = out.at[:].set(ret)
58
- return ret
59
-
60
- def subtract(self, x1, x2, out = None, *args, **kwargs):
61
- x1 = self.to_backend_array(x1)
62
- x2 = self.to_backend_array(x2)
63
- ret = self._array_backend.subtract(x1, x2, *args, **kwargs)
64
- if out is not None:
65
- out = out.at[:].set(ret)
66
- return ret
67
-
68
- def multiply(self, x1, x2, out = None, *args, **kwargs):
69
- x1 = self.to_backend_array(x1)
70
- x2 = self.to_backend_array(x2)
71
- ret = self._array_backend.multiply(x1, x2, *args, **kwargs)
72
- if out is not None:
73
- out = out.at[:].set(ret)
74
- return ret
75
-
76
- def divide(self, x1, x2, out = None, *args, **kwargs):
77
- x1 = self.to_backend_array(x1)
78
- x2 = self.to_backend_array(x2)
79
- ret = self._array_backend.divide(x1, x2, *args, **kwargs)
80
- if out is not None:
81
- out = out.at[:].set(ret)
82
- return ret
83
-
84
- def fill(self, arr, value: float) -> None:
85
- arr.at[:].set(value)
86
-
80
+ def _create_ufuncs(self):
81
+ ufuncs = [
82
+ "add",
83
+ "subtract",
84
+ "multiply",
85
+ "divide",
86
+ "square",
87
+ "sqrt",
88
+ "maximum",
89
+ ]
90
+ for ufunc in ufuncs:
91
+ backend_method = emulate_out(getattr(self._array_backend, ufunc))
92
+ setattr(self, ufunc, staticmethod(backend_method))
93
+
94
+ ufuncs = ["zeros", "full"]
95
+ for ufunc in ufuncs:
96
+ backend_method = getattr(self._array_backend, ufunc)
97
+ setattr(self, ufunc, staticmethod(backend_method))
98
+
99
+ def fill(self, arr: BackendArray, value: float) -> BackendArray:
100
+ return self._array_backend.full(
101
+ shape=arr.shape, dtype=arr.dtype, fill_value=value
102
+ )
87
103
 
88
104
  def build_fft(
89
105
  self,
@@ -92,127 +108,189 @@ class JaxBackend(NumpyFFTWBackend):
92
108
  inverse_fast_shape: Tuple[int] = None,
93
109
  **kwargs,
94
110
  ) -> Tuple[Callable, Callable]:
95
- """
96
- Build fft builder functions.
97
-
98
- Parameters
99
- ----------
100
- fast_shape : tuple
101
- Tuple of integers corresponding to fast convolution shape
102
- (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
103
- fast_ft_shape : tuple
104
- Tuple of integers corresponding to the shape of the Fourier
105
- transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
106
- inverse_fast_shape : tuple, optional
107
- Output shape of the inverse Fourier transform. By default fast_shape.
108
- **kwargs : dict, optional
109
- Unused keyword arguments.
110
-
111
- Returns
112
- -------
113
- tuple
114
- Tupple containing callable rfft and irfft object.
115
- """
116
111
  if inverse_fast_shape is None:
117
112
  inverse_fast_shape = fast_shape
118
113
 
119
- def rfftn(
120
- arr, out, shape: Tuple[int] = fast_shape
121
- ) -> None:
122
- out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
114
+ def rfftn(arr, out, shape=fast_shape):
115
+ return self._array_backend.fft.rfftn(arr, s=shape)
123
116
 
124
- def irfftn(
125
- arr, out, shape: Tuple[int] = inverse_fast_shape
126
- ) -> None:
127
- out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
117
+ def irfftn(arr, out, shape=fast_shape):
118
+ return self._array_backend.fft.irfftn(arr, s=shape)
128
119
 
129
120
  return rfftn, irfftn
130
121
 
131
- def sharedarr_to_arr(self, shm, shape: Tuple[int], dtype: str):
132
- return shm
133
-
134
- @staticmethod
135
- def arr_to_sharedarr(arr, shared_memory_handler: type = None):
136
- return arr
137
-
138
- def rotate_array(
122
+ def rigid_transform(
139
123
  self,
140
- arr,
141
- rotation_matrix,
142
- arr_mask = None,
143
- translation = None,
144
- use_geometric_center: bool = False,
145
- out = None,
146
- out_mask = None,
147
- order: int = 3,
148
- ) -> None:
124
+ arr: BackendArray,
125
+ rotation_matrix: BackendArray,
126
+ out: BackendArray = None,
127
+ out_mask: BackendArray = None,
128
+ translation: BackendArray = None,
129
+ arr_mask: BackendArray = None,
130
+ order: int = 1,
131
+ **kwargs,
132
+ ) -> Tuple[BackendArray, BackendArray]:
149
133
  rotate_mask = arr_mask is not None
150
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
151
- translation = self.zeros(arr.ndim) if translation is None else translation
152
134
 
153
- indices = self._array_backend.indices(arr.shape).reshape(
154
- (len(arr.shape), -1)
155
- ).astype(self._float_dtype)
156
-
157
- center = self.divide(arr.shape, 2)
158
- if not use_geometric_center:
159
- center = self.center_of_mass(arr, cutoff=0)
160
- center = center[:, None]
135
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
136
+ indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
137
+ indices = indices.reshape((arr.ndim, -1))
161
138
  indices = indices.at[:].add(-center)
162
- rotation_matrix = self._array_backend.linalg.inv(rotation_matrix)
163
- indices = self._array_backend.matmul(rotation_matrix, indices)
139
+ indices = self._array_backend.matmul(rotation_matrix.T, indices)
164
140
  indices = indices.at[:].add(center)
141
+ if translation is not None:
142
+ indices = indices.at[:].add(translation)
165
143
 
166
- out = self.zeros_like(arr) if out is None else out
167
- out_slice = tuple(slice(0, stop) for stop in arr.shape)
168
-
169
- out = out.at[out_slice].set(
170
- self.scipy.ndimage.map_coordinates(
171
- arr, indices, order=order
172
- ).reshape(arr.shape)
144
+ out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
145
+ arr.shape
173
146
  )
174
147
 
148
+ out_mask = arr_mask
175
149
  if rotate_mask:
176
- out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
177
- out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
178
- out_mask = out_mask.at[out_mask_slice].set(
179
- self.scipy.ndimage.map_coordinates(
180
- arr_mask, indices, order=order
181
- ).reshape(arr.shape)
182
- )
150
+ out_mask = self.scipy.ndimage.map_coordinates(
151
+ arr_mask, indices, order=order
152
+ ).reshape(arr_mask.shape)
183
153
 
184
- match return_type:
185
- case 0:
186
- return None
187
- case 1:
188
- return out
189
- case 2:
190
- return out_mask
191
- case 3:
192
- return out, out_mask
154
+ return out, out_mask
193
155
 
194
156
  def max_score_over_rotations(
195
157
  self,
196
- score_space,
197
- internal_scores,
198
- internal_rotations,
158
+ scores: BackendArray,
159
+ max_scores: BackendArray,
160
+ rotations: BackendArray,
199
161
  rotation_index: int,
200
- ):
162
+ ) -> Tuple[BackendArray, BackendArray]:
163
+ update = self.greater(max_scores, scores)
164
+ max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
165
+ rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
166
+ return max_scores, rotations
167
+
168
+ def _scan(
169
+ self,
170
+ matching_data: type,
171
+ splits: Tuple[Tuple[slice, slice]],
172
+ n_jobs: int,
173
+ callback_class,
174
+ rotate_mask: bool = False,
175
+ **kwargs,
176
+ ) -> List:
201
177
  """
202
- Modify internal_scores and internal_rotations inplace with scores and rotation
203
- index respectively, wherever score_sapce is larger than internal scores.
204
-
205
- Parameters
206
- ----------
207
- score_space : CupyArray
208
- The score space to compare against internal_scores.
209
- internal_scores : CupyArray
210
- The internal scores to update with maximum scores.
211
- internal_rotations : CupyArray
212
- The internal rotations corresponding to the maximum scores.
213
- rotation_index : int
214
- The index representing the current rotation.
178
+ Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
179
+ :py:class:`tme.analyzer.MaxScoreOverRotations`.
215
180
  """
216
- indices = score_space > internal_scores
217
- internal_scores.at[indices].set(score_space[indices])
218
- internal_rotations.at[indices].set(rotation_index)
181
+ from ._jax_utils import scan as scan_inner
182
+
183
+ pad_target = True if len(splits) > 1 else False
184
+ convolution_mode = "valid" if pad_target else "same"
185
+ target_pad = matching_data.target_padding(pad_target=pad_target)
186
+
187
+ target_shape = tuple(
188
+ (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
189
+ )
190
+ conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
191
+ target_shape=self.to_numpy_array(target_shape),
192
+ template_shape=self.to_numpy_array(matching_data._template.shape),
193
+ pad_fourier=False,
194
+ )
195
+
196
+ analyzer_args = {
197
+ "convolution_mode": convolution_mode,
198
+ "fourier_shift": shift,
199
+ "targetshape": target_shape,
200
+ "templateshape": matching_data.template.shape,
201
+ "convolution_shape": conv_shape,
202
+ }
203
+
204
+ create_target_filter = matching_data.target_filter is not None
205
+ create_template_filter = matching_data.template_filter is not None
206
+ create_filter = create_target_filter or create_template_filter
207
+
208
+ # Applying the filter leads to more FFTs
209
+ fastt_shape = matching_data._template.shape
210
+ if create_template_filter:
211
+ _, fastt_shape, _, tshift = matching_data._fourier_padding(
212
+ target_shape=self.to_numpy_array(matching_data._template.shape),
213
+ template_shape=self.to_numpy_array(
214
+ [1 for _ in matching_data._template.shape]
215
+ ),
216
+ pad_fourier=False,
217
+ )
218
+
219
+ ret, template_filter, target_filter = [], 1, 1
220
+ rotation_mapping = {
221
+ self.tobytes(matching_data.rotations[i]): i
222
+ for i in range(matching_data.rotations.shape[0])
223
+ }
224
+ for split_start in range(0, len(splits), n_jobs):
225
+ split_subset = splits[split_start : (split_start + n_jobs)]
226
+ if not len(split_subset):
227
+ continue
228
+
229
+ targets, translation_offsets = [], []
230
+ for target_split, template_split in split_subset:
231
+ base = matching_data.subset_by_slice(
232
+ target_slice=target_split,
233
+ target_pad=target_pad,
234
+ template_slice=template_split,
235
+ )
236
+ translation_offsets.append(base._translation_offset)
237
+ targets.append(self.topleft_pad(base._target, fast_shape))
238
+
239
+ if create_filter:
240
+ filter_args = {
241
+ "data_rfft": self.fft.rfftn(targets[0]),
242
+ "return_real_fourier": True,
243
+ "shape_is_real_fourier": False,
244
+ }
245
+
246
+ if create_template_filter:
247
+ template_filter = matching_data.template_filter(
248
+ shape=fastt_shape, **filter_args
249
+ )["data"]
250
+ template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
251
+
252
+ if create_target_filter:
253
+ target_filter = matching_data.target_filter(
254
+ shape=fast_shape, **filter_args
255
+ )["data"]
256
+ target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
257
+
258
+ create_filter, create_template_filter, create_target_filter = (False,) * 3
259
+ base, targets = None, self._array_backend.stack(targets)
260
+ scores, rotations = scan_inner(
261
+ targets,
262
+ self.topleft_pad(matching_data.template, fastt_shape),
263
+ self.topleft_pad(matching_data.template_mask, fastt_shape),
264
+ matching_data.rotations,
265
+ template_filter,
266
+ target_filter,
267
+ fast_shape,
268
+ rotate_mask,
269
+ )
270
+
271
+ for index in range(scores.shape[0]):
272
+ temp = callback_class(
273
+ scores=scores[index],
274
+ rotations=rotations[index],
275
+ thread_safe=False,
276
+ offset=translation_offsets[index],
277
+ )
278
+ temp.rotation_mapping = rotation_mapping
279
+ ret.append(tuple(temp._postprocess(**analyzer_args)))
280
+
281
+ return ret
282
+
283
+ def get_available_memory(self) -> int:
284
+ import jax
285
+
286
+ _memory = {"cpu": 0, "gpu": 0}
287
+ for device in jax.devices():
288
+ if device.platform == "cpu":
289
+ _memory["cpu"] = super().get_available_memory()
290
+ else:
291
+ mem_stats = device.memory_stats()
292
+ _memory["gpu"] += mem_stats.get("bytes_limit", 0)
293
+
294
+ if _memory["gpu"] > 0:
295
+ return _memory["gpu"]
296
+ return _memory["cpu"]