pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -8,12 +8,12 @@ from math import log, sqrt
8
8
  from typing import Tuple, Dict
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import NDArray
12
11
  from scipy.ndimage import mean as ndimean
13
12
  from scipy.ndimage import map_coordinates
14
13
 
15
- from ._utils import fftfreqn, crop_real_fourier, shift_fourier
16
- from ..backends import backend
14
+ from ..types import BackendArray
15
+ from ..backends import backend as be
16
+ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_shape
17
17
 
18
18
 
19
19
  class BandPassFilter:
@@ -29,7 +29,7 @@ class BandPassFilter:
29
29
  highpass : float, optional
30
30
  The highpass cutoff, defaults to None.
31
31
  sampling_rate : Tuple[float], optional
32
- The sampling rate in Fourier space, defaults to 1.
32
+ The sampling r_position_to_molmapate in Fourier space, defaults to 1.
33
33
  use_gaussian : bool, optional
34
34
  Whether to use Gaussian bandpass filter, defaults to True.
35
35
  return_real_fourier : bool, optional
@@ -63,7 +63,7 @@ class BandPassFilter:
63
63
  return_real_fourier: bool = False,
64
64
  shape_is_real_fourier: bool = False,
65
65
  **kwargs,
66
- ) -> NDArray:
66
+ ) -> BackendArray:
67
67
  """
68
68
  Generate a bandpass filter using discrete frequency cutoffs.
69
69
 
@@ -86,7 +86,7 @@ class BandPassFilter:
86
86
 
87
87
  Returns:
88
88
  --------
89
- NDArray
89
+ BackendArray
90
90
  The bandpass filter in Fourier space.
91
91
  """
92
92
  if shape_is_real_fourier:
@@ -127,7 +127,7 @@ class BandPassFilter:
127
127
  return_real_fourier: bool = False,
128
128
  shape_is_real_fourier: bool = False,
129
129
  **kwargs,
130
- ) -> NDArray:
130
+ ) -> BackendArray:
131
131
  """
132
132
  Generate a bandpass filter using Gaussian functions.
133
133
 
@@ -150,7 +150,7 @@ class BandPassFilter:
150
150
 
151
151
  Returns:
152
152
  --------
153
- NDArray
153
+ BackendArray
154
154
  The bandpass filter in Fourier space.
155
155
  """
156
156
  if shape_is_real_fourier:
@@ -162,29 +162,32 @@ class BandPassFilter:
162
162
  shape_is_real_fourier=shape_is_real_fourier,
163
163
  compute_euclidean_norm=True,
164
164
  )
165
- grid = -backend.square(grid)
165
+ grid = be.to_backend_array(grid)
166
+ grid = -be.square(grid)
166
167
 
167
168
  lowpass_filter, highpass_filter = 1, 1
168
169
  norm = float(sqrt(2 * log(2)))
169
- upper_sampling = float(backend.max(backend.multiply(2, sampling_rate)))
170
+ upper_sampling = float(
171
+ be.max(be.multiply(2, be.to_backend_array(sampling_rate)))
172
+ )
170
173
 
171
174
  if lowpass is not None:
172
175
  lowpass = float(lowpass)
173
- lowpass = backend.maximum(lowpass, backend.eps(backend._float_dtype))
176
+ lowpass = be.maximum(lowpass, be.eps(be._float_dtype))
174
177
  if highpass is not None:
175
178
  highpass = float(highpass)
176
- highpass = backend.maximum(highpass, backend.eps(backend._float_dtype))
179
+ highpass = be.maximum(highpass, be.eps(be._float_dtype))
177
180
 
178
181
  if lowpass is not None:
179
182
  lowpass = upper_sampling / (lowpass * norm)
180
- lowpass = backend.multiply(2, backend.square(lowpass))
181
- lowpass_filter = backend.exp(backend.divide(grid, lowpass))
183
+ lowpass = be.multiply(2, be.square(lowpass))
184
+ lowpass_filter = be.exp(be.divide(grid, lowpass))
182
185
  if highpass is not None:
183
186
  highpass = upper_sampling / (highpass * norm)
184
- highpass = backend.multiply(2, backend.square(highpass))
185
- highpass_filter = 1 - backend.exp(backend.divide(grid, highpass))
187
+ highpass = be.multiply(2, be.square(highpass))
188
+ highpass_filter = 1 - be.exp(be.divide(grid, highpass))
186
189
 
187
- bandpass_filter = backend.multiply(lowpass_filter, highpass_filter)
190
+ bandpass_filter = be.multiply(lowpass_filter, highpass_filter)
188
191
  bandpass_filter = shift_fourier(
189
192
  data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
190
193
  )
@@ -205,7 +208,7 @@ class BandPassFilter:
205
208
  mask = func(**func_args)
206
209
 
207
210
  return {
208
- "data": backend.to_backend_array(mask),
211
+ "data": be.to_backend_array(mask),
209
212
  "sampling_rate": func_args.get("sampling_rate", 1),
210
213
  "is_multiplicative_filter": True,
211
214
  }
@@ -237,14 +240,14 @@ class LinearWhiteningFilter:
237
240
 
238
241
  @staticmethod
239
242
  def _compute_spectrum(
240
- data_rfft: NDArray, n_bins: int = None, batch_dimension: int = None
241
- ) -> Tuple[NDArray, NDArray]:
243
+ data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
244
+ ) -> Tuple[BackendArray, BackendArray]:
242
245
  """
243
246
  Compute the spectrum of the input data.
244
247
 
245
248
  Parameters:
246
249
  -----------
247
- data_rfft : NDArray
250
+ data_rfft : BackendArray
248
251
  The Fourier transform of the input data.
249
252
  n_bins : int, optional
250
253
  The number of bins for computing the spectrum, defaults to None.
@@ -253,9 +256,9 @@ class LinearWhiteningFilter:
253
256
 
254
257
  Returns:
255
258
  --------
256
- bins : NDArray
259
+ bins : BackendArray
257
260
  Array containing the bin indices for the spectrum.
258
- radial_averages : NDArray
261
+ radial_averages : BackendArray
259
262
  Array containing the radial averages of the spectrum.
260
263
  """
261
264
  shape = tuple(x for i, x in enumerate(data_rfft.shape) if i != batch_dimension)
@@ -270,7 +273,7 @@ class LinearWhiteningFilter:
270
273
  shape_is_real_fourier=True,
271
274
  compute_euclidean_norm=True,
272
275
  )
273
- bins = backend.to_numpy_array(bins)
276
+ bins = be.to_numpy_array(bins)
274
277
 
275
278
  # Implicit lowpass to nyquist
276
279
  bins = np.floor(bins * (n_bins - 1) + 0.5).astype(int)
@@ -292,11 +295,11 @@ class LinearWhiteningFilter:
292
295
 
293
296
  @staticmethod
294
297
  def _interpolate_spectrum(
295
- spectrum: NDArray,
298
+ spectrum: BackendArray,
296
299
  shape: Tuple[int],
297
300
  shape_is_real_fourier: bool = True,
298
301
  order: int = 1,
299
- ) -> NDArray:
302
+ ) -> BackendArray:
300
303
  """
301
304
  References
302
305
  ----------
@@ -306,19 +309,19 @@ class LinearWhiteningFilter:
306
309
  """
307
310
  grid = fftfreqn(
308
311
  shape=shape,
309
- sampling_rate=.5,
312
+ sampling_rate=0.5,
310
313
  shape_is_real_fourier=shape_is_real_fourier,
311
314
  compute_euclidean_norm=True,
312
315
  )
313
- grid = backend.to_numpy_array(grid)
314
- np.multiply(grid, (spectrum.shape[0] - 1), out = grid) + 0.5
316
+ grid = be.to_numpy_array(grid)
317
+ np.multiply(grid, (spectrum.shape[0] - 1), out=grid) + 0.5
315
318
  spectrum = map_coordinates(spectrum, grid.reshape(1, -1), order=order)
316
319
  return spectrum.reshape(grid.shape)
317
320
 
318
321
  def __call__(
319
322
  self,
320
- data: NDArray = None,
321
- data_rfft: NDArray = None,
323
+ data: BackendArray = None,
324
+ data_rfft: BackendArray = None,
322
325
  n_bins: int = None,
323
326
  batch_dimension: int = None,
324
327
  order: int = 1,
@@ -329,9 +332,9 @@ class LinearWhiteningFilter:
329
332
 
330
333
  Parameters:
331
334
  -----------
332
- data : NDArray, optional
335
+ data : BackendArray, optional
333
336
  The input data, defaults to None.
334
- data_rfft : NDArray, optional
337
+ data_rfft : BackendArray, optional
335
338
  The Fourier transform of the input data, defaults to None.
336
339
  n_bins : int, optional
337
340
  The number of bins for computing the spectrum, defaults to None.
@@ -348,9 +351,9 @@ class LinearWhiteningFilter:
348
351
  Filter data and associated parameters.
349
352
  """
350
353
  if data_rfft is None:
351
- data_rfft = np.fft.rfftn(backend.to_numpy_array(data))
354
+ data_rfft = np.fft.rfftn(be.to_numpy_array(data))
352
355
 
353
- data_rfft = backend.to_numpy_array(data_rfft)
356
+ data_rfft = be.to_numpy_array(data_rfft)
354
357
 
355
358
  bins, radial_averages = self._compute_spectrum(
356
359
  data_rfft, n_bins, batch_dimension
@@ -358,21 +361,28 @@ class LinearWhiteningFilter:
358
361
 
359
362
  if order is None:
360
363
  cutoff = bins < radial_averages.size
361
- filter_mask = np.zeros(data_rfft.shape, radial_averages.dtype)
364
+ filter_mask = np.zeros(bins.shape, radial_averages.dtype)
362
365
  filter_mask[cutoff] = radial_averages[bins[cutoff]]
363
366
  else:
367
+ shape = bins.shape
368
+ if kwargs.get("shape", False):
369
+ shape = compute_fourier_shape(
370
+ shape=kwargs.get("shape"),
371
+ shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
372
+ )
373
+
364
374
  filter_mask = self._interpolate_spectrum(
365
375
  spectrum=radial_averages,
366
- shape=data_rfft.shape,
376
+ shape=shape,
367
377
  shape_is_real_fourier=True,
368
378
  )
369
379
 
370
380
  filter_mask = np.fft.ifftshift(
371
381
  filter_mask,
372
- axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension)
382
+ axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension),
373
383
  )
374
384
 
375
385
  return {
376
- "data": backend.to_backend_array(filter_mask),
386
+ "data": be.to_backend_array(filter_mask),
377
387
  "is_multiplicative_filter": True,
378
388
  }
@@ -9,17 +9,15 @@ from typing import Tuple, Dict
9
9
  from dataclasses import dataclass
10
10
 
11
11
  import numpy as np
12
- from numpy.typing import NDArray
13
12
 
14
13
  from .. import Preprocessor
15
- from ..backends import backend
14
+ from ..types import NDArray
15
+ from ..backends import backend as be
16
16
  from ..matching_utils import euler_to_rotationmatrix
17
-
18
17
  from ._utils import (
19
18
  frequency_grid_at_angle,
20
19
  compute_tilt_shape,
21
20
  crop_real_fourier,
22
- centered_grid,
23
21
  fftfreqn,
24
22
  shift_fourier,
25
23
  )
@@ -51,7 +49,6 @@ def create_reconstruction_filter(
51
49
  +---------------+----------------------------------------------------+
52
50
  | hamming | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_ |
53
51
  +---------------+----------------------------------------------------+
54
-
55
52
  kwargs: Dict
56
53
  Keyword arguments for particular filter_types.
57
54
 
@@ -195,22 +192,20 @@ class ReconstructFromTilt:
195
192
  if data.shape == shape:
196
193
  return data
197
194
 
198
- data = backend.to_backend_array(data)
199
- volume_temp = backend.zeros(shape, dtype=backend._float_dtype)
200
- volume_temp_rotated = backend.zeros(shape, dtype=backend._float_dtype)
201
- volume = backend.zeros(shape, dtype=backend._float_dtype)
195
+ data = be.to_backend_array(data)
196
+ volume_temp = be.zeros(shape, dtype=be._float_dtype)
197
+ volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
198
+ volume = be.zeros(shape, dtype=be._float_dtype)
202
199
 
203
- slices = tuple(
204
- slice(a, a + 1) for a in backend.astype(backend.divide(shape, 2), int)
205
- )
200
+ slices = tuple(slice(a, a + 1) for a in be.astype(be.divide(shape, 2), int))
206
201
  subset = tuple(
207
202
  slice(None) if i != opening_axis else slices[opening_axis]
208
203
  for i in range(len(shape))
209
204
  )
210
- angles_loop = backend.zeros(len(shape))
205
+ angles_loop = be.zeros(len(shape))
211
206
  wedge_dim = [x for x in data.shape]
212
207
  wedge_dim.insert(1 + opening_axis, 1)
213
- wedges = backend.reshape(data, wedge_dim)
208
+ wedges = be.reshape(data, wedge_dim)
214
209
 
215
210
  rec_filter = 1
216
211
  if reconstruction_filter is not None:
@@ -226,31 +221,29 @@ class ReconstructFromTilt:
226
221
  if tilt_axis == 1 and opening_axis == 0:
227
222
  rec_filter = rec_filter.T
228
223
 
229
- rec_filter = backend.to_backend_array(rec_filter)
230
- rec_filter = backend.reshape(rec_filter, wedges[0].shape)
224
+ rec_filter = be.to_backend_array(rec_filter)
225
+ rec_filter = be.reshape(rec_filter, wedges[0].shape)
231
226
 
232
227
  for index in range(len(angles)):
233
- backend.fill(angles_loop, 0)
234
- backend.fill(volume_temp, 0)
235
- backend.fill(volume_temp_rotated, 0)
228
+ be.fill(angles_loop, 0)
229
+ be.fill(volume_temp, 0)
230
+ be.fill(volume_temp_rotated, 0)
236
231
 
237
232
  volume_temp[subset] = wedges[index] * rec_filter
238
233
 
239
234
  angles_loop[tilt_axis] = angles[index]
240
- angles_loop = backend.roll(angles_loop, (opening_axis - 1,), axis=0)
241
- rotation_matrix = euler_to_rotationmatrix(
242
- backend.to_numpy_array(angles_loop)
243
- )
244
- rotation_matrix = backend.to_backend_array(rotation_matrix)
235
+ angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
236
+ rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
237
+ rotation_matrix = be.to_backend_array(rotation_matrix)
245
238
 
246
- backend.rotate_array(
239
+ be.rigid_transform(
247
240
  arr=volume_temp,
248
241
  rotation_matrix=rotation_matrix,
249
242
  out=volume_temp_rotated,
250
243
  use_geometric_center=True,
251
244
  order=interpolation_order,
252
245
  )
253
- backend.add(volume, volume_temp_rotated, out=volume)
246
+ be.add(volume, volume_temp_rotated, out=volume)
254
247
 
255
248
  volume = shift_fourier(data=volume, shape_is_real_fourier=False)
256
249
 
@@ -387,7 +380,7 @@ class Wedge:
387
380
  func_args["weights"] = np.cos(np.radians(self.angles))
388
381
 
389
382
  ret = weight_types[weight_type](**func_args)
390
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
383
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
391
384
 
392
385
  return {
393
386
  "data": ret,
@@ -483,7 +476,7 @@ class Wedge:
483
476
  reduce_dim=True,
484
477
  )
485
478
 
486
- wedges = np.zeros((len(self.angles), *tilt_shape), dtype=backend._float_dtype)
479
+ wedges = np.zeros((len(self.angles), *tilt_shape), dtype=be._float_dtype)
487
480
  for index, angle in enumerate(self.angles):
488
481
  frequency_grid = frequency_grid_at_angle(
489
482
  shape=self.shape,
@@ -573,7 +566,7 @@ class WedgeReconstructed:
573
566
  func = self.continuous_wedge
574
567
 
575
568
  ret = func(shape=shape, **func_args)
576
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
569
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
577
570
 
578
571
  return {
579
572
  "data": ret,
@@ -664,7 +657,7 @@ class WedgeReconstructed:
664
657
  """
665
658
  preprocessor = Preprocessor()
666
659
 
667
- angles = np.asarray(backend.to_numpy_array(angles))
660
+ angles = np.asarray(be.to_numpy_array(angles))
668
661
  weights = np.ones(angles.size)
669
662
  if weight_wedge:
670
663
  weights = np.cos(np.radians(angles))
@@ -858,7 +851,7 @@ class CTF:
858
851
  func_args["opening_axis"] = None
859
852
 
860
853
  ret = self.weight(**func_args)
861
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
854
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
862
855
  return {
863
856
  "data": ret,
864
857
  "angles": func_args["angles"],
@@ -941,24 +934,26 @@ class CTF:
941
934
  shape=shape, opening_axis=opening_axis, reduce_dim=True
942
935
  )
943
936
  stack = np.zeros((len(angles), *tilt_shape))
944
- electron_wavelength = self._compute_electron_wavelength() / sampling_rate
945
937
 
946
938
  correct_defocus_gradient &= len(shape) == 3
947
939
  correct_defocus_gradient &= tilt_axis is not None
948
940
  correct_defocus_gradient &= opening_axis is not None
949
941
 
950
942
  for index, angle in enumerate(angles):
951
- grid = backend.to_numpy_array(centered_grid(shape=tilt_shape))
952
- grid = np.divide(grid.T, sampling_rate).T
953
-
954
943
  defocus_x, defocus_y = defoci_x[index], defoci_y[index]
955
944
 
945
+ if correct_defocus_gradient or defocus_y is not None:
946
+ grid = fftfreqn(
947
+ shape=shape,
948
+ sampling_rate=be.divide(sampling_rate, shape),
949
+ return_sparse_grid=True,
950
+ )
951
+
956
952
  # This should be done after defocus_x computation
957
953
  if correct_defocus_gradient:
958
954
  angle_rad = np.radians(angle)
959
955
 
960
956
  defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
961
-
962
957
  remaining_axis = tuple(
963
958
  i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
964
959
  )[0]
@@ -983,7 +978,6 @@ class CTF:
983
978
  angle=angle,
984
979
  sampling_rate=1,
985
980
  )
986
- frequency_grid *= frequency_grid <= 0.5
987
981
  np.square(frequency_grid, out=frequency_grid)
988
982
 
989
983
  electron_aberration = spherical_aberration * electron_wavelength**2
@@ -1001,13 +995,13 @@ class CTF:
1001
995
  np.sin(-chi, out=chi)
1002
996
  stack[index] = chi
1003
997
 
998
+ # Avoid contrast inversion
999
+ np.negative(stack, out=stack)
1004
1000
  if flip_phase:
1005
1001
  np.abs(stack, out=stack)
1006
1002
 
1007
- np.negative(stack, out=stack)
1008
1003
  stack = np.squeeze(stack)
1009
-
1010
- stack = backend.to_backend_array(stack)
1004
+ stack = be.to_backend_array(stack)
1011
1005
 
1012
1006
  if len(angles) == 1:
1013
1007
  stack = shift_fourier(data=stack, shape_is_real_fourier=False)