pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -5,12 +5,13 @@
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
 
8
- from typing import Tuple
8
+ from typing import Tuple, List
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import NDArray
12
11
 
13
- from ..backends import backend
12
+ from ..backends import backend as be
13
+ from ..backends import NumpyFFTWBackend
14
+ from ..types import BackendArray, NDArray
14
15
  from ..matching_utils import euler_to_rotationmatrix
15
16
 
16
17
 
@@ -93,18 +94,27 @@ def frequency_grid_at_angle(
93
94
  tilt_shape = compute_tilt_shape(
94
95
  shape=shape, opening_axis=opening_axis, reduce_dim=False
95
96
  )
96
- index_grid = centered_grid(shape=tilt_shape)
97
+
98
+ if angle == 0:
99
+ index_grid = fftfreqn(
100
+ tuple(x for x in tilt_shape if x != 1),
101
+ sampling_rate=1,
102
+ compute_euclidean_norm=True,
103
+ )
104
+
97
105
  if angle != 0:
98
106
  angles = np.zeros(len(shape))
99
107
  angles[tilt_axis] = angle
100
108
  rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
109
+
110
+ index_grid = fftfreqn(tilt_shape, sampling_rate=None)
101
111
  index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
112
+ norm = np.multiply(sampling_rate, shape).astype(int)
102
113
 
103
- norm = np.divide(1, 2 * sampling_rate * np.divide(shape, 2).astype(int))
114
+ index_grid = np.divide(index_grid.T, norm).T
115
+ index_grid = np.squeeze(index_grid)
116
+ index_grid = np.linalg.norm(index_grid, axis=(0))
104
117
 
105
- index_grid = np.multiply(index_grid.T, norm).T
106
- index_grid = np.squeeze(index_grid)
107
- index_grid = np.linalg.norm(index_grid, axis=(0))
108
118
  return index_grid
109
119
 
110
120
 
@@ -113,9 +123,10 @@ def fftfreqn(
113
123
  sampling_rate: Tuple[float],
114
124
  compute_euclidean_norm: bool = False,
115
125
  shape_is_real_fourier: bool = False,
126
+ return_sparse_grid: bool = False,
116
127
  ) -> NDArray:
117
128
  """
118
- Generate the n-dimensional discrete Fourier Transform sample frequencies.
129
+ Generate the n-dimensional discrete Fourier transform sample frequencies.
119
130
 
120
131
  Parameters:
121
132
  -----------
@@ -133,56 +144,74 @@ def fftfreqn(
133
144
  NDArray
134
145
  The sample frequencies.
135
146
  """
136
- center = backend.astype(backend.divide(shape, 2), backend._int_dtype)
137
-
138
- norm = np.ones(len(shape))
147
+ # There is no real need to have these operations on GPU right now
148
+ temp_backend = NumpyFFTWBackend()
149
+ norm = temp_backend.full(len(shape), fill_value=1)
150
+ center = temp_backend.astype(temp_backend.divide(shape, 2), temp_backend._int_dtype)
139
151
  if sampling_rate is not None:
140
- norm = backend.astype(backend.multiply(shape, sampling_rate), int)
152
+ norm = temp_backend.astype(temp_backend.multiply(shape, sampling_rate), int)
141
153
 
142
154
  if shape_is_real_fourier:
143
- center[-1] = 0
144
- norm[-1] = 1
155
+ center[-1], norm[-1] = 0, 1
145
156
  if sampling_rate is not None:
146
157
  norm[-1] = (shape[-1] - 1) * 2 * sampling_rate
147
158
 
148
- indices = backend.transpose(backend.indices(shape))
149
- indices -= center
150
- indices = backend.divide(indices, norm)
151
- indices = backend.transpose(indices)
159
+ grids = []
160
+ for i, x in enumerate(shape):
161
+ baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
162
+ grid = (temp_backend.arange(x) - center[i]) / norm[i]
163
+ grids.append(temp_backend.reshape(grid, baseline_dims))
152
164
 
153
165
  if compute_euclidean_norm:
154
- indices = backend.square(indices)
155
- indices = backend.sum(indices, axis=0)
156
- backend.sqrt(indices, out=indices)
166
+ grids = sum(temp_backend.square(x) for x in grids)
167
+ grids = temp_backend.sqrt(grids, out=grids)
168
+ return grids
169
+
170
+ if return_sparse_grid:
171
+ return grids
157
172
 
158
- return indices
173
+ grid_flesh = temp_backend.full(shape, fill_value=1)
174
+ grids = temp_backend.stack(tuple(grid * grid_flesh for grid in grids))
159
175
 
176
+ return grids
160
177
 
161
- def crop_real_fourier(data: NDArray) -> NDArray:
178
+
179
+ def crop_real_fourier(data: BackendArray) -> BackendArray:
162
180
  """
163
181
  Crop the real part of a Fourier transform.
164
182
 
165
183
  Parameters:
166
184
  -----------
167
- data : NDArray
185
+ data : BackendArray
168
186
  The Fourier transformed data.
169
187
 
170
188
  Returns:
171
189
  --------
172
- NDArray
190
+ BackendArray
173
191
  The cropped data.
174
192
  """
175
193
  stop = 1 + (data.shape[-1] // 2)
176
194
  return data[..., :stop]
177
195
 
178
196
 
179
- def shift_fourier(data: NDArray, shape_is_real_fourier: bool = False):
180
- shift = backend.add(
181
- backend.astype(backend.divide(data.shape, 2), int),
182
- backend.mod(data.shape, 2),
183
- )
197
+ def compute_fourier_shape(
198
+ shape: Tuple[int], shape_is_real_fourier: bool = False
199
+ ) -> List[int]:
200
+ if shape_is_real_fourier:
201
+ return shape
202
+ shape = [int(x) for x in shape]
203
+ shape[-1] = 1 + shape[-1] // 2
204
+ return shape
205
+
206
+
207
+ def shift_fourier(
208
+ data: BackendArray, shape_is_real_fourier: bool = False
209
+ ) -> BackendArray:
210
+ shape = be.to_backend_array(data.shape)
211
+ shift = be.add(be.divide(shape, 2), be.mod(shape, 2))
212
+ shift = [int(x) for x in shift]
184
213
  if shape_is_real_fourier:
185
214
  shift[-1] = 0
186
215
 
187
- data = backend.roll(data, shift, tuple(i for i in range(len(shift))))
216
+ data = be.roll(data, shift, tuple(i for i in range(len(shift))))
188
217
  return data
@@ -17,15 +17,18 @@ class ComposableFilter(ABC):
17
17
  @abstractmethod
18
18
  def __call__(self, *args, **kwargs) -> Dict:
19
19
  """
20
- Parameters:
21
- -----------
20
+
21
+ Parameters
22
+ ----------
23
+
22
24
  *args : tuple
23
25
  Variable length argument list.
24
26
  **kwargs : dict
25
27
  Arbitrary keyword arguments.
26
28
 
27
- Returns:
28
- --------
29
+ Returns
30
+ -------
31
+
29
32
  Dict
30
33
  A dictionary representing the result of the filtering operation.
31
34
  """
@@ -7,7 +7,7 @@
7
7
 
8
8
  from typing import Tuple, Dict
9
9
 
10
- from tme.backends import backend
10
+ from tme.backends import backend as be
11
11
 
12
12
 
13
13
  class Compose:
@@ -42,9 +42,13 @@ class Compose:
42
42
  kwargs.update(meta)
43
43
  ret = transform(**kwargs)
44
44
 
45
+ if "data" not in ret:
46
+ continue
47
+
45
48
  if ret.get("is_multiplicative_filter", False):
46
- backend.multiply(ret["data"], meta["data"], out=ret["data"])
47
- ret["merge"] = None
49
+ prev_data = meta.pop("data")
50
+ ret["data"] = be.multiply(ret["data"], prev_data, out=ret["data"])
51
+ ret["merge"], prev_data = None, None
48
52
 
49
53
  meta = ret
50
54
 
@@ -8,12 +8,12 @@ from math import log, sqrt
8
8
  from typing import Tuple, Dict
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import NDArray
12
11
  from scipy.ndimage import mean as ndimean
13
12
  from scipy.ndimage import map_coordinates
14
13
 
15
- from ._utils import fftfreqn, crop_real_fourier, shift_fourier
16
- from ..backends import backend
14
+ from ..types import BackendArray
15
+ from ..backends import backend as be
16
+ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_shape
17
17
 
18
18
 
19
19
  class BandPassFilter:
@@ -29,7 +29,7 @@ class BandPassFilter:
29
29
  highpass : float, optional
30
30
  The highpass cutoff, defaults to None.
31
31
  sampling_rate : Tuple[float], optional
32
- The sampling rate in Fourier space, defaults to 1.
32
+ The sampling r_position_to_molmapate in Fourier space, defaults to 1.
33
33
  use_gaussian : bool, optional
34
34
  Whether to use Gaussian bandpass filter, defaults to True.
35
35
  return_real_fourier : bool, optional
@@ -63,7 +63,7 @@ class BandPassFilter:
63
63
  return_real_fourier: bool = False,
64
64
  shape_is_real_fourier: bool = False,
65
65
  **kwargs,
66
- ) -> NDArray:
66
+ ) -> BackendArray:
67
67
  """
68
68
  Generate a bandpass filter using discrete frequency cutoffs.
69
69
 
@@ -86,7 +86,7 @@ class BandPassFilter:
86
86
 
87
87
  Returns:
88
88
  --------
89
- NDArray
89
+ BackendArray
90
90
  The bandpass filter in Fourier space.
91
91
  """
92
92
  if shape_is_real_fourier:
@@ -127,7 +127,7 @@ class BandPassFilter:
127
127
  return_real_fourier: bool = False,
128
128
  shape_is_real_fourier: bool = False,
129
129
  **kwargs,
130
- ) -> NDArray:
130
+ ) -> BackendArray:
131
131
  """
132
132
  Generate a bandpass filter using Gaussian functions.
133
133
 
@@ -150,7 +150,7 @@ class BandPassFilter:
150
150
 
151
151
  Returns:
152
152
  --------
153
- NDArray
153
+ BackendArray
154
154
  The bandpass filter in Fourier space.
155
155
  """
156
156
  if shape_is_real_fourier:
@@ -162,29 +162,32 @@ class BandPassFilter:
162
162
  shape_is_real_fourier=shape_is_real_fourier,
163
163
  compute_euclidean_norm=True,
164
164
  )
165
- grid = -backend.square(grid)
165
+ grid = be.to_backend_array(grid)
166
+ grid = -be.square(grid)
166
167
 
167
168
  lowpass_filter, highpass_filter = 1, 1
168
169
  norm = float(sqrt(2 * log(2)))
169
- upper_sampling = float(backend.max(backend.multiply(2, sampling_rate)))
170
+ upper_sampling = float(
171
+ be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
172
+ )
170
173
 
171
174
  if lowpass is not None:
172
175
  lowpass = float(lowpass)
173
- lowpass = backend.maximum(lowpass, backend.eps(backend._float_dtype))
176
+ lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
174
177
  if highpass is not None:
175
178
  highpass = float(highpass)
176
- highpass = backend.maximum(highpass, backend.eps(backend._float_dtype))
179
+ highpass = be.maximum(highpass, be.eps(be._float_dtype))
177
180
 
178
181
  if lowpass is not None:
179
182
  lowpass = upper_sampling / (lowpass * norm)
180
- lowpass = backend.multiply(2, backend.square(lowpass))
181
- lowpass_filter = backend.exp(backend.divide(grid, lowpass))
183
+ lowpass = be.multiply(2, be.square(lowpass))
184
+ lowpass_filter = be.exp(be.divide(grid, lowpass))
182
185
  if highpass is not None:
183
186
  highpass = upper_sampling / (highpass * norm)
184
- highpass = backend.multiply(2, backend.square(highpass))
185
- highpass_filter = 1 - backend.exp(backend.divide(grid, highpass))
187
+ highpass = be.multiply(2, be.square(highpass))
188
+ highpass_filter = 1 - be.exp(be.divide(grid, highpass))
186
189
 
187
- bandpass_filter = backend.multiply(lowpass_filter, highpass_filter)
190
+ bandpass_filter = be.multiply(lowpass_filter, highpass_filter)
188
191
  bandpass_filter = shift_fourier(
189
192
  data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
190
193
  )
@@ -205,7 +208,7 @@ class BandPassFilter:
205
208
  mask = func(**func_args)
206
209
 
207
210
  return {
208
- "data": backend.to_backend_array(mask),
211
+ "data": be.to_backend_array(mask),
209
212
  "sampling_rate": func_args.get("sampling_rate", 1),
210
213
  "is_multiplicative_filter": True,
211
214
  }
@@ -237,14 +240,14 @@ class LinearWhiteningFilter:
237
240
 
238
241
  @staticmethod
239
242
  def _compute_spectrum(
240
- data_rfft: NDArray, n_bins: int = None, batch_dimension: int = None
241
- ) -> Tuple[NDArray, NDArray]:
243
+ data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
244
+ ) -> Tuple[BackendArray, BackendArray]:
242
245
  """
243
246
  Compute the spectrum of the input data.
244
247
 
245
248
  Parameters:
246
249
  -----------
247
- data_rfft : NDArray
250
+ data_rfft : BackendArray
248
251
  The Fourier transform of the input data.
249
252
  n_bins : int, optional
250
253
  The number of bins for computing the spectrum, defaults to None.
@@ -253,9 +256,9 @@ class LinearWhiteningFilter:
253
256
 
254
257
  Returns:
255
258
  --------
256
- bins : NDArray
259
+ bins : BackendArray
257
260
  Array containing the bin indices for the spectrum.
258
- radial_averages : NDArray
261
+ radial_averages : BackendArray
259
262
  Array containing the radial averages of the spectrum.
260
263
  """
261
264
  shape = tuple(x for i, x in enumerate(data_rfft.shape) if i != batch_dimension)
@@ -270,7 +273,7 @@ class LinearWhiteningFilter:
270
273
  shape_is_real_fourier=True,
271
274
  compute_euclidean_norm=True,
272
275
  )
273
- bins = backend.to_numpy_array(bins)
276
+ bins = be.to_numpy_array(bins)
274
277
 
275
278
  # Implicit lowpass to nyquist
276
279
  bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)
@@ -292,11 +295,11 @@ class LinearWhiteningFilter:
292
295
 
293
296
  @staticmethod
294
297
  def _interpolate_spectrum(
295
- spectrum: NDArray,
298
+ spectrum: BackendArray,
296
299
  shape: Tuple[int],
297
300
  shape_is_real_fourier: bool = True,
298
301
  order: int = 1,
299
- ) -> NDArray:
302
+ ) -> BackendArray:
300
303
  """
301
304
  References
302
305
  ----------
@@ -306,19 +309,19 @@ class LinearWhiteningFilter:
306
309
  """
307
310
  grid = fftfreqn(
308
311
  shape=shape,
309
- sampling_rate=.5,
312
+ sampling_rate=0.5,
310
313
  shape_is_real_fourier=shape_is_real_fourier,
311
314
  compute_euclidean_norm=True,
312
315
  )
313
- grid = backend.to_numpy_array(grid)
314
- np.multiply(grid, (spectrum.shape[0] - 1), out = grid) + 0.5
316
+ grid = be.to_numpy_array(grid)
317
+ np.multiply(grid, (spectrum.shape[0] - 1), out=grid) + 0.5
315
318
  spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
316
319
  return spectrum.reshape(grid.shape)
317
320
 
318
321
  def __call__(
319
322
  self,
320
- data: NDArray = None,
321
- data_rfft: NDArray = None,
323
+ data: BackendArray = None,
324
+ data_rfft: BackendArray = None,
322
325
  n_bins: int = None,
323
326
  batch_dimension: int = None,
324
327
  order: int = 1,
@@ -329,9 +332,9 @@ class LinearWhiteningFilter:
329
332
 
330
333
  Parameters:
331
334
  -----------
332
- data : NDArray, optional
335
+ data : BackendArray, optional
333
336
  The input data, defaults to None.
334
- data_rfft : NDArray, optional
337
+ data_rfft : BackendArray, optional
335
338
  The Fourier transform of the input data, defaults to None.
336
339
  n_bins : int, optional
337
340
  The number of bins for computing the spectrum, defaults to None.
@@ -348,9 +351,9 @@ class LinearWhiteningFilter:
348
351
  Filter data and associated parameters.
349
352
  """
350
353
  if data_rfft is None:
351
- data_rfft = np.fft.rfftn(backend.to_numpy_array(data))
354
+ data_rfft = np.fft.rfftn(be.to_numpy_array(data))
352
355
 
353
- data_rfft = backend.to_numpy_array(data_rfft)
356
+ data_rfft = be.to_numpy_array(data_rfft)
354
357
 
355
358
  bins, radial_averages = self._compute_spectrum(
356
359
  data_rfft, n_bins, batch_dimension
@@ -358,21 +361,28 @@ class LinearWhiteningFilter:
358
361
 
359
362
  if order is None:
360
363
  cutoff = bins < radial_averages.size
361
- filter_mask = np.zeros(data_rfft.shape, radial_averages.dtype)
364
+ filter_mask = np.zeros(bins.shape, radial_averages.dtype)
362
365
  filter_mask[cutoff] = radial_averages[bins[cutoff]]
363
366
  else:
367
+ shape = bins.shape
368
+ if kwargs.get("shape", False):
369
+ shape = compute_fourier_shape(
370
+ shape=kwargs.get("shape"),
371
+ shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
372
+ )
373
+
364
374
  filter_mask = self._interpolate_spectrum(
365
375
  spectrum=radial_averages,
366
- shape=data_rfft.shape,
376
+ shape=shape,
367
377
  shape_is_real_fourier=True,
368
378
  )
369
379
 
370
380
  filter_mask = np.fft.ifftshift(
371
381
  filter_mask,
372
- axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension)
382
+ axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension),
373
383
  )
374
384
 
375
385
  return {
376
- "data": backend.to_backend_array(filter_mask),
386
+ "data": be.to_backend_array(filter_mask),
377
387
  "is_multiplicative_filter": True,
378
388
  }