pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -9,17 +9,15 @@ from typing import Tuple, Dict
9
9
  from dataclasses import dataclass
10
10
 
11
11
  import numpy as np
12
- from numpy.typing import NDArray
13
12
 
14
13
  from .. import Preprocessor
15
- from ..backends import backend
14
+ from ..types import NDArray
15
+ from ..backends import backend as be
16
16
  from ..matching_utils import euler_to_rotationmatrix
17
-
18
17
  from ._utils import (
19
18
  frequency_grid_at_angle,
20
19
  compute_tilt_shape,
21
20
  crop_real_fourier,
22
- centered_grid,
23
21
  fftfreqn,
24
22
  shift_fourier,
25
23
  )
@@ -51,7 +49,6 @@ def create_reconstruction_filter(
51
49
  +---------------+----------------------------------------------------+
52
50
  | hamming | |w| * (.54 + .46 ( cos(|w| * pi))) [2]_ |
53
51
  +---------------+----------------------------------------------------+
54
-
55
52
  kwargs: Dict
56
53
  Keyword arguments for particular filter_types.
57
54
 
@@ -195,22 +192,20 @@ class ReconstructFromTilt:
195
192
  if data.shape == shape:
196
193
  return data
197
194
 
198
- data = backend.to_backend_array(data)
199
- volume_temp = backend.zeros(shape, dtype=backend._float_dtype)
200
- volume_temp_rotated = backend.zeros(shape, dtype=backend._float_dtype)
201
- volume = backend.zeros(shape, dtype=backend._float_dtype)
195
+ data = be.to_backend_array(data)
196
+ volume_temp = be.zeros(shape, dtype=be._float_dtype)
197
+ volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
198
+ volume = be.zeros(shape, dtype=be._float_dtype)
202
199
 
203
- slices = tuple(
204
- slice(a, a + 1) for a in backend.astype(backend.divide(shape, 2), int)
205
- )
200
+ slices = tuple(slice(a, a + 1) for a in be.astype(be.divide(shape, 2), int))
206
201
  subset = tuple(
207
202
  slice(None) if i != opening_axis else slices[opening_axis]
208
203
  for i in range(len(shape))
209
204
  )
210
- angles_loop = backend.zeros(len(shape))
205
+ angles_loop = be.zeros(len(shape))
211
206
  wedge_dim = [x for x in data.shape]
212
207
  wedge_dim.insert(1 + opening_axis, 1)
213
- wedges = backend.reshape(data, wedge_dim)
208
+ wedges = be.reshape(data, wedge_dim)
214
209
 
215
210
  rec_filter = 1
216
211
  if reconstruction_filter is not None:
@@ -226,31 +221,29 @@ class ReconstructFromTilt:
226
221
  if tilt_axis == 1 and opening_axis == 0:
227
222
  rec_filter = rec_filter.T
228
223
 
229
- rec_filter = backend.to_backend_array(rec_filter)
230
- rec_filter = backend.reshape(rec_filter, wedges[0].shape)
224
+ rec_filter = be.to_backend_array(rec_filter)
225
+ rec_filter = be.reshape(rec_filter, wedges[0].shape)
231
226
 
232
227
  for index in range(len(angles)):
233
- backend.fill(angles_loop, 0)
234
- backend.fill(volume_temp, 0)
235
- backend.fill(volume_temp_rotated, 0)
228
+ be.fill(angles_loop, 0)
229
+ be.fill(volume_temp, 0)
230
+ be.fill(volume_temp_rotated, 0)
236
231
 
237
232
  volume_temp[subset] = wedges[index] * rec_filter
238
233
 
239
234
  angles_loop[tilt_axis] = angles[index]
240
- angles_loop = backend.roll(angles_loop, (opening_axis - 1,), axis=0)
241
- rotation_matrix = euler_to_rotationmatrix(
242
- backend.to_numpy_array(angles_loop)
243
- )
244
- rotation_matrix = backend.to_backend_array(rotation_matrix)
235
+ angles_loop = be.roll(angles_loop, (opening_axis - 1,), axis=0)
236
+ rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles_loop))
237
+ rotation_matrix = be.to_backend_array(rotation_matrix)
245
238
 
246
- backend.rotate_array(
239
+ be.rigid_transform(
247
240
  arr=volume_temp,
248
241
  rotation_matrix=rotation_matrix,
249
242
  out=volume_temp_rotated,
250
243
  use_geometric_center=True,
251
244
  order=interpolation_order,
252
245
  )
253
- backend.add(volume, volume_temp_rotated, out=volume)
246
+ be.add(volume, volume_temp_rotated, out=volume)
254
247
 
255
248
  volume = shift_fourier(data=volume, shape_is_real_fourier=False)
256
249
 
@@ -387,7 +380,7 @@ class Wedge:
387
380
  func_args["weights"] = np.cos(np.radians(self.angles))
388
381
 
389
382
  ret = weight_types[weight_type](**func_args)
390
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
383
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
391
384
 
392
385
  return {
393
386
  "data": ret,
@@ -483,7 +476,7 @@ class Wedge:
483
476
  reduce_dim=True,
484
477
  )
485
478
 
486
- wedges = np.zeros((len(self.angles), *tilt_shape), dtype=backend._float_dtype)
479
+ wedges = np.zeros((len(self.angles), *tilt_shape), dtype=be._float_dtype)
487
480
  for index, angle in enumerate(self.angles):
488
481
  frequency_grid = frequency_grid_at_angle(
489
482
  shape=self.shape,
@@ -573,7 +566,7 @@ class WedgeReconstructed:
573
566
  func = self.continuous_wedge
574
567
 
575
568
  ret = func(shape=shape, **func_args)
576
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
569
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
577
570
 
578
571
  return {
579
572
  "data": ret,
@@ -664,7 +657,7 @@ class WedgeReconstructed:
664
657
  """
665
658
  preprocessor = Preprocessor()
666
659
 
667
- angles = np.asarray(backend.to_numpy_array(angles))
660
+ angles = np.asarray(be.to_numpy_array(angles))
668
661
  weights = np.ones(angles.size)
669
662
  if weight_wedge:
670
663
  weights = np.cos(np.radians(angles))
@@ -680,6 +673,9 @@ class WedgeReconstructed:
680
673
  omit_negative_frequencies=return_real_fourier,
681
674
  )
682
675
 
676
+ if not weight_wedge:
677
+ ret = (ret > 0) * 1.0
678
+
683
679
  return ret
684
680
 
685
681
 
@@ -797,16 +793,6 @@ class CTF:
797
793
  def __post_init__(self):
798
794
  self.defocus_angle = np.radians(self.defocus_angle)
799
795
 
800
- kwargs = {
801
- "defocus_x": self.defocus_x,
802
- "defocus_y": self.defocus_y,
803
- "spherical_aberration": self.spherical_aberration,
804
- }
805
- kwargs = {k: v for k, v in kwargs.items() if v is not None}
806
- self._update_parameters(
807
- electron_wavelength=self._compute_electron_wavelength(), **kwargs
808
- )
809
-
810
796
  def _compute_electron_wavelength(self, acceleration_voltage: int = None):
811
797
  """Computes the wavelength of an electron in angstrom."""
812
798
 
@@ -825,28 +811,10 @@ class CTF:
825
811
  electron_wavelength = np.divide(
826
812
  planck_constant * light_velocity, np.sqrt(denominator)
827
813
  )
814
+ # Convert to Ångstrom
828
815
  electron_wavelength *= 1e10
829
816
  return electron_wavelength
830
817
 
831
- def _update_parameters(self, **kwargs):
832
- """Update multiple parameters of the CTF instance."""
833
- voxel_based = [
834
- "electron_wavelength",
835
- "spherical_aberration",
836
- "defocus_x",
837
- "defocus_y",
838
- ]
839
- if "sampling_rate" in kwargs:
840
- self.sampling_rate = kwargs["sampling_rate"]
841
-
842
- if "acceleration_voltage" in kwargs:
843
- kwargs["electron_wavelength"] = self._compute_electron_wavelength()
844
-
845
- for key, value in kwargs.items():
846
- if key in voxel_based and value is not None:
847
- value = np.divide(value, np.max(self.sampling_rate))
848
- setattr(self, key, value)
849
-
850
818
  def __call__(self, **kwargs) -> NDArray:
851
819
  func_args = vars(self).copy()
852
820
  func_args.update(kwargs)
@@ -858,7 +826,7 @@ class CTF:
858
826
  func_args["opening_axis"] = None
859
827
 
860
828
  ret = self.weight(**func_args)
861
- ret = backend.astype(backend.to_backend_array(ret), backend._float_dtype)
829
+ ret = be.astype(be.to_backend_array(ret), be._float_dtype)
862
830
  return {
863
831
  "data": ret,
864
832
  "angles": func_args["angles"],
@@ -872,7 +840,6 @@ class CTF:
872
840
  shape: Tuple[int],
873
841
  defocus_x: Tuple[float],
874
842
  angles: Tuple[float],
875
- electron_wavelength: float = None,
876
843
  opening_axis: int = None,
877
844
  tilt_axis: int = None,
878
845
  amplitude_contrast: float = 0.07,
@@ -898,8 +865,6 @@ class CTF:
898
865
  The defocus value in x direction.
899
866
  angles : tuple of float
900
867
  The tilt angles.
901
- electron_wavelength : float, optional
902
- The electron wavelength, defaults to None.
903
868
  opening_axis : int, optional
904
869
  The axis around which the wedge is opened, defaults to None.
905
870
  tilt_axis : int, optional
@@ -941,24 +906,30 @@ class CTF:
941
906
  shape=shape, opening_axis=opening_axis, reduce_dim=True
942
907
  )
943
908
  stack = np.zeros((len(angles), *tilt_shape))
944
- electron_wavelength = self._compute_electron_wavelength() / sampling_rate
945
909
 
946
910
  correct_defocus_gradient &= len(shape) == 3
947
911
  correct_defocus_gradient &= tilt_axis is not None
948
912
  correct_defocus_gradient &= opening_axis is not None
949
913
 
914
+ spherical_aberration /= sampling_rate
915
+ electron_wavelength = self._compute_electron_wavelength() / sampling_rate
950
916
  for index, angle in enumerate(angles):
951
- grid = backend.to_numpy_array(centered_grid(shape=tilt_shape))
952
- grid = np.divide(grid.T, sampling_rate).T
953
-
954
917
  defocus_x, defocus_y = defoci_x[index], defoci_y[index]
955
918
 
919
+ defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
920
+ defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
921
+ if correct_defocus_gradient or defocus_y is not None:
922
+ grid = fftfreqn(
923
+ shape=shape,
924
+ sampling_rate=be.divide(sampling_rate, shape),
925
+ return_sparse_grid=True,
926
+ )
927
+
956
928
  # This should be done after defocus_x computation
957
929
  if correct_defocus_gradient:
958
930
  angle_rad = np.radians(angle)
959
931
 
960
932
  defocus_gradient = np.multiply(grid[1], np.sin(angle_rad))
961
-
962
933
  remaining_axis = tuple(
963
934
  i for i in range(len(shape)) if i not in (opening_axis, tilt_axis)
964
935
  )[0]
@@ -983,7 +954,7 @@ class CTF:
983
954
  angle=angle,
984
955
  sampling_rate=1,
985
956
  )
986
- frequency_grid *= frequency_grid <= 0.5
957
+ frequency_mask = frequency_grid < 0.5
987
958
  np.square(frequency_grid, out=frequency_grid)
988
959
 
989
960
  electron_aberration = spherical_aberration * electron_wavelength**2
@@ -999,15 +970,16 @@ class CTF:
999
970
  )
1000
971
  )
1001
972
  np.sin(-chi, out=chi)
973
+ np.multiply(chi, frequency_mask, out=chi)
1002
974
  stack[index] = chi
1003
975
 
976
+ # Avoid contrast inversion
977
+ np.negative(stack, out=stack)
1004
978
  if flip_phase:
1005
979
  np.abs(stack, out=stack)
1006
980
 
1007
- np.negative(stack, out=stack)
1008
981
  stack = np.squeeze(stack)
1009
-
1010
- stack = backend.to_backend_array(stack)
982
+ stack = be.to_backend_array(stack)
1011
983
 
1012
984
  if len(angles) == 1:
1013
985
  stack = shift_fourier(data=stack, shape_is_real_fourier=False)