pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (33)
  1. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/match_template.py +91 -142
  2. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +6 -9
  5. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/METADATA +11 -10
  6. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/RECORD +33 -32
  7. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  8. scripts/match_template.py +91 -142
  9. scripts/postprocess.py +20 -29
  10. scripts/preprocess.py +95 -56
  11. scripts/preprocessor_gui.py +6 -9
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +9 -6
  14. tme/backends/__init__.py +1 -1
  15. tme/backends/_jax_utils.py +10 -8
  16. tme/backends/cupy_backend.py +2 -7
  17. tme/backends/jax_backend.py +34 -20
  18. tme/backends/npfftw_backend.py +3 -2
  19. tme/backends/pytorch_backend.py +10 -7
  20. tme/density.py +15 -8
  21. tme/extensions.cpython-311-darwin.so +0 -0
  22. tme/matching_data.py +24 -17
  23. tme/matching_exhaustive.py +36 -19
  24. tme/matching_scores.py +5 -2
  25. tme/matching_utils.py +7 -2
  26. tme/orientations.py +26 -9
  27. tme/preprocessing/composable_filter.py +7 -4
  28. tme/preprocessing/tilt_series.py +10 -32
  29. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  30. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  31. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/WHEEL +0 -0
  32. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  33. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_data.py CHANGED
@@ -171,6 +171,7 @@ class MatchingData:
                 np.subtract(right_pad, data_voxels_right),
             )
         )
+        # The reflections are later cropped from the scores
         arr = np.pad(arr, padding, mode="reflect")

         if invert:
@@ -467,29 +468,35 @@ class MatchingData:

         pad_shape = np.maximum(target_shape, template_shape)
         ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
-        convolution_shape, fast_shape, fast_ft_shape = ret
+        conv_shape, fast_shape, fast_ft_shape = ret
+
+        template_mod = np.mod(template_shape, 2)
         if not pad_fourier:
             fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
-            fourier_shift -= np.mod(template_shape, 2)
-            shape_diff = np.subtract(fast_shape, convolution_shape)
-            shape_diff = np.divide(shape_diff, 2).astype(int)
-            shape_diff = np.multiply(shape_diff, 1 - batch_mask)
-            np.add(fourier_shift, shape_diff, out=fourier_shift)
-
-            fourier_shift = fourier_shift.astype(int)
+            fourier_shift = np.subtract(fourier_shift, template_mod)

-        shape_diff = np.subtract(target_shape, template_shape)
-        shape_diff = np.multiply(shape_diff, 1 - batch_mask)
-        if np.sum(shape_diff < 0) and not pad_fourier:
+        shape_diff = np.multiply(
+            np.subtract(target_shape, template_shape), 1 - batch_mask
+        )
+        if np.sum(shape_diff < 0):
             warnings.warn(
-                "Target is larger than template and Fourier padding is turned off. "
-                "This may lead to inaccurate results. Prefer swapping template and target, "
-                "enable padding or turn off template centering."
+                "Template is larger than target and padding is turned off. Consider "
+                "swapping them or activate padding. Correcting the shift for now."
             )
-            fourier_shift = np.subtract(fourier_shift, np.divide(shape_diff, 2))
-            fourier_shift = fourier_shift.astype(int)

-        return tuple(fast_shape), tuple(fast_ft_shape), tuple(fourier_shift)
+            shape_shift = np.divide(shape_diff, 2)
+            offset = np.mod(shape_diff, 2)
+            if pad_fourier:
+                offset = -np.subtract(
+                    offset,
+                    np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
+                )
+
+            shape_shift = np.add(shape_shift, offset)
+            fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
+
+        fourier_shift = tuple(fourier_shift.astype(int))
+        return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift

     def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
         """
tme/matching_exhaustive.py CHANGED
@@ -73,35 +73,47 @@ def _setup_template_filter_apply_target_filter(
     if not filter_template and not filter_target:
         return template_filter

-    target_temp = be.topleft_pad(matching_data.target, fast_shape)
-    target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
-
     inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
     filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
     filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
-
     fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
     fast_shape = tuple(int(x) for x in fast_shape if x != 0)

+    fastt_shape, fastt_ft_shape = fast_shape, filter_shape
+    if filter_template and not pad_template_filter:
+        # FFT shape acrobatics for faster filter application
+        _, fastt_shape, _, _ = matching_data._fourier_padding(
+            target_shape=be.to_numpy_array(matching_data._template.shape),
+            template_shape=be.to_numpy_array(
+                [1 for _ in matching_data._template.shape]
+            ),
+            pad_fourier=False,
+        )
+        matching_data.template = be.reverse(
+            be.topleft_pad(matching_data.template, fastt_shape)
+        )
+        matching_data.template_mask = be.reverse(
+            be.topleft_pad(matching_data.template_mask, fastt_shape)
+        )
+        matching_data._set_matching_dimension(
+            target_dims=matching_data._target_dims,
+            template_dims=matching_data._template_dims,
+        )
+        fastt_ft_shape = [int(x) for x in matching_data._output_template_shape]
+        fastt_ft_shape[-1] = fastt_ft_shape[-1] // 2 + 1
+
+    target_temp = be.topleft_pad(matching_data.target, fast_shape)
+    target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
     target_temp_ft = rfftn(target_temp, target_temp_ft)
     if filter_template:
-        # TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
-        template_fast_shape, template_filter_shape = fast_shape, filter_shape
-        if not pad_template_filter:
-            template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
-            template_filter_shape = [
-                int(x) for x in matching_data._output_template_shape
-            ]
-            template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
-
         template_filter = matching_data.template_filter(
-            shape=template_fast_shape,
+            shape=fastt_shape,
             return_real_fourier=True,
             shape_is_real_fourier=False,
             data_rfft=target_temp_ft,
             batch_dimension=matching_data._target_dims,
         )["data"]
-        template_filter = be.reshape(template_filter, template_filter_shape)
+        template_filter = be.reshape(template_filter, fastt_ft_shape)

     if filter_target:
         target_filter = matching_data.target_filter(
@@ -212,9 +224,13 @@ def scan(

     """
     matching_data.to_backend()
-    fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
-        pad_fourier=pad_fourier
-    )
+    (
+        conv_shape,
+        fast_shape,
+        fast_ft_shape,
+        fourier_shift,
+    ) = matching_data.fourier_padding(pad_fourier=pad_fourier)
+    template_shape = matching_data.template.shape

     rfftn, irfftn = be.build_fft(
         fast_shape=fast_shape,
@@ -256,7 +272,8 @@ def scan(
         "fourier_shift": fourier_shift,
         "convolution_mode": convmode,
         "targetshape": matching_data.target.shape,
-        "templateshape": matching_data.template.shape,
+        "templateshape": template_shape,
+        "convolution_shape": conv_shape,
         "fast_shape": fast_shape,
         "indices": getattr(matching_data, "indices", None),
         "shared_memory_handler": shared_memory_handler,
tme/matching_scores.py CHANGED
@@ -86,7 +86,7 @@ def _setup_template_filtering(
     forward_ft_shape = template_shape
     inverse_ft_shape = template_filter.shape

-    if rfftn is not None and irfftn is not None:
+    if (rfftn is not None and irfftn is not None) or shape_mismatch:
         rfftn, irfftn = be.build_fft(
             fast_shape=forward_ft_shape,
             fast_ft_shape=inverse_ft_shape,
@@ -109,7 +109,10 @@ def _setup_template_filtering(

     def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
         _template[:] = template[real_subset]
-        return _apply_template_filter(_template, _ft_temp, template_filter)
+        template[real_subset] = _apply_template_filter(
+            _template, _ft_temp, template_filter
+        )
+        return template

     return _apply_filter_shape_mismatch

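The shape-mismatch branch now writes the filtered values back into the padded template instead of returning only the subset. A generic sketch of that pattern, not pytme's API, with illustrative names and shapes:

```python
import numpy as np

padded = np.zeros((48, 48))                  # template padded to a fast shape
padded[:32, :32] = np.random.rand(32, 32)    # the actual template occupies a corner

subset = tuple(slice(0, 32) for _ in range(2))
lowpass = np.ones((32, 17))                  # filter matching the rfft2 shape of the block

# Filter only the occupied sub-block, then write it back into the padded
# buffer so downstream code keeps operating on the full fast-shape array.
block = padded[subset]
filtered = np.fft.irfft2(np.fft.rfft2(block) * lowpass, s=block.shape)
padded[subset] = filtered
```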
tme/matching_utils.py CHANGED
@@ -467,6 +467,7 @@ def apply_convolution_mode(
     convolution_mode: str,
     s1: Tuple[int],
     s2: Tuple[int],
+    convolution_shape: Tuple[int] = None,
     mask_output: bool = False,
 ) -> BackendArray:
     """
@@ -490,6 +491,8 @@ def apply_convolution_mode(
         Tuple of integers corresponding to shape of convolution array 1.
     s2 : tuple of ints
         Tuple of integers corresponding to shape of convolution array 2.
+    convolution_shape : tuple of ints, optional
+        Size of the actually computed convolution. s1 + s2 - 1 by default.
     mask_output : bool, optional
         Whether to mask values outside of convolution_mode rather than
         removing them. Defaults to False.
@@ -500,7 +503,9 @@ def apply_convolution_mode(
         The array after applying the convolution mode.
     """
     # Remove padding to next fast Fourier length
-    arr = arr[tuple(slice(s1[i] + s2[i] - 1) for i in range(len(s1)))]
+    if convolution_shape is None:
+        convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
+    arr = arr[tuple(slice(0, x) for x in convolution_shape)]

     if convolution_mode not in ("full", "same", "valid"):
         raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
@@ -1220,7 +1225,7 @@ def scramble_phases(
     arr: NDArray,
     noise_proportion: float = 0.5,
     seed: int = 42,
-    normalize_power: bool = True,
+    normalize_power: bool = False,
 ) -> NDArray:
     """
     Perform random phase scrambling of ``arr``.
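`apply_convolution_mode` can now crop to an explicitly passed `convolution_shape` instead of assuming `s1 + s2 - 1`. For reference, a short sketch of what the three supported modes correspond to, using SciPy rather than pytme:

```python
import numpy as np
from scipy.signal import fftconvolve

s1, s2 = (100,), (32,)
target, template = np.random.rand(*s1), np.random.rand(*s2)

full = fftconvolve(target, template, mode="full")    # s1 + s2 - 1 -> (131,)
same = fftconvolve(target, template, mode="same")    # s1          -> (100,)
valid = fftconvolve(target, template, mode="valid")  # s1 - s2 + 1 -> (69,)

print(full.shape, same.shape, valid.shape)
```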
tme/orientations.py CHANGED
@@ -6,6 +6,7 @@
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 import re
+import warnings
 from collections import deque
 from dataclasses import dataclass
 from string import ascii_lowercase
@@ -301,7 +302,7 @@ class Orientations:
     def _to_relion_star(
         self,
         filename: str,
-        name_prefix: str = None,
+        name: str = None,
         ctf_image: str = None,
         sampling_rate: float = 1.0,
         subtomogram_size: int = 0,
@@ -313,8 +314,9 @@ class Orientations:
         ----------
         filename : str
             The name of the file to save the orientations.
-        name_prefix : str, optional
-            A prefix to add to the image names in the STAR file.
+        name : str or list of str, optional
+            Path to image file the orientation is in reference to. If name is a list
+            its assumed to correspond to _rlnImageName, otherwise _rlnMicrographName.
         ctf_image : str, optional
             Path to CTF or wedge mask RELION.
         sampling_rate : float, optional
@@ -352,6 +354,21 @@ class Orientations:
         optics_header = "\n".join(optics_header)
         optics_data = "\t".join(optics_data)

+        if name is None:
+            name = ""
+            warnings.warn(
+                "Consider specifying the name argument. A single string will be "
+                "interpreted as path to the original micrograph, a list of strings "
+                "as path to individual subsets."
+            )
+
+        name_reference = "_rlnImageName"
+        if isinstance(name, str):
+            name = [
+                name,
+            ] * self.translations.shape[0]
+            name_reference = "_rlnMicrographName"
+
         header = [
             "data_particles",
             "",
@@ -359,7 +376,7 @@ class Orientations:
             "_rlnCoordinateX",
             "_rlnCoordinateY",
             "_rlnCoordinateZ",
-            "_rlnImageName",
+            name_reference,
             "_rlnAngleRot",
             "_rlnAngleTilt",
             "_rlnAnglePsi",
@@ -371,8 +388,6 @@ class Orientations:
         ctf_image = "" if ctf_image is None else f"\t{ctf_image}"

         header = "\n".join(header)
-        name_prefix = "" if name_prefix is None else name_prefix
-
         with open(filename, mode="w", encoding="utf-8") as ofile:
             _ = ofile.write(f"{optics_header}\n")
             _ = ofile.write(f"{optics_data}\n")
@@ -387,9 +402,8 @@ class Orientations:

             translation_string = "\t".join([str(x) for x in translation][::-1])
             angle_string = "\t".join([str(x) for x in rotation])
-            name = f"{name_prefix}_{index}.mrc"
             _ = ofile.write(
-                f"{translation_string}\t{name}\t{angle_string}\t1{ctf_image}\n"
+                f"{translation_string}\t{name[index]}\t{angle_string}\t1{ctf_image}\n"
             )

         return None
@@ -584,7 +598,10 @@ class Orientations:
         cls, filename: str, delimiter: str = None
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         ret = cls._parse_star(filename=filename, delimiter=delimiter)
-        ret = ret["data_particles"]
+
+        ret = ret.get("data_particles", None)
+        if ret is None:
+            raise ValueError(f"No data_particles section found in {filename}.")

         translation = np.vstack(
             (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
tme/preprocessing/composable_filter.py CHANGED
@@ -17,15 +17,18 @@ class ComposableFilter(ABC):
     @abstractmethod
     def __call__(self, *args, **kwargs) -> Dict:
         """
-        Parameters:
-        -----------
+
+        Parameters
+        ----------
+
         *args : tuple
             Variable length argument list.
         **kwargs : dict
             Arbitrary keyword arguments.

-        Returns:
-        --------
+        Returns
+        -------
+
         Dict
             A dictionary representing the result of the filtering operation.
         """
tme/preprocessing/tilt_series.py CHANGED
@@ -673,6 +673,9 @@ class WedgeReconstructed:
             omit_negative_frequencies=return_real_fourier,
         )

+        if not weight_wedge:
+            ret = (ret > 0) * 1.0
+
         return ret


@@ -790,16 +793,6 @@ class CTF:
     def __post_init__(self):
         self.defocus_angle = np.radians(self.defocus_angle)

-        kwargs = {
-            "defocus_x": self.defocus_x,
-            "defocus_y": self.defocus_y,
-            "spherical_aberration": self.spherical_aberration,
-        }
-        kwargs = {k: v for k, v in kwargs.items() if v is not None}
-        self._update_parameters(
-            electron_wavelength=self._compute_electron_wavelength(), **kwargs
-        )
-
     def _compute_electron_wavelength(self, acceleration_voltage: int = None):
         """Computes the wavelength of an electron in angstrom."""

@@ -818,28 +811,10 @@ class CTF:
         electron_wavelength = np.divide(
             planck_constant * light_velocity, np.sqrt(denominator)
         )
+        # Convert to Ångstrom
         electron_wavelength *= 1e10
         return electron_wavelength

-    def _update_parameters(self, **kwargs):
-        """Update multiple parameters of the CTF instance."""
-        voxel_based = [
-            "electron_wavelength",
-            "spherical_aberration",
-            "defocus_x",
-            "defocus_y",
-        ]
-        if "sampling_rate" in kwargs:
-            self.sampling_rate = kwargs["sampling_rate"]
-
-        if "acceleration_voltage" in kwargs:
-            kwargs["electron_wavelength"] = self._compute_electron_wavelength()
-
-        for key, value in kwargs.items():
-            if key in voxel_based and value is not None:
-                value = np.divide(value, np.max(self.sampling_rate))
-            setattr(self, key, value)
-
     def __call__(self, **kwargs) -> NDArray:
         func_args = vars(self).copy()
         func_args.update(kwargs)
@@ -865,7 +840,6 @@ class CTF:
         shape: Tuple[int],
         defocus_x: Tuple[float],
         angles: Tuple[float],
-        electron_wavelength: float = None,
         opening_axis: int = None,
         tilt_axis: int = None,
         amplitude_contrast: float = 0.07,
@@ -891,8 +865,6 @@ class CTF:
             The defocus value in x direction.
         angles : tuple of float
             The tilt angles.
-        electron_wavelength : float, optional
-            The electron wavelength, defaults to None.
         opening_axis : int, optional
             The axis around which the wedge is opened, defaults to None.
         tilt_axis : int, optional
@@ -939,9 +911,13 @@ class CTF:
         correct_defocus_gradient &= tilt_axis is not None
         correct_defocus_gradient &= opening_axis is not None

+        spherical_aberration /= sampling_rate
+        electron_wavelength = self._compute_electron_wavelength() / sampling_rate
         for index, angle in enumerate(angles):
             defocus_x, defocus_y = defoci_x[index], defoci_y[index]

+            defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
+            defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
             if correct_defocus_gradient or defocus_y is not None:
                 grid = fftfreqn(
                     shape=shape,
@@ -978,6 +954,7 @@ class CTF:
                 angle=angle,
                 sampling_rate=1,
             )
+            frequency_mask = frequency_grid < 0.5
             np.square(frequency_grid, out=frequency_grid)

             electron_aberration = spherical_aberration * electron_wavelength**2
@@ -993,6 +970,7 @@ class CTF:
                 )
             )
             np.sin(-chi, out=chi)
+            np.multiply(chi, frequency_mask, out=chi)
             stack[index] = chi

         # Avoid contrast inversion
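The CTF no longer rescales its parameters in `__post_init__`; sampling-rate division and the wavelength computation now happen at call time, and the new `frequency_mask` appears to restrict the CTF to frequencies below 0.5 cycles per voxel (the Nyquist limit of the sampling grid). For context, a hedged, independent computation of the relativistic electron wavelength in Ångstrom using the standard formula, not pytme's exact code:

```python
import numpy as np


def electron_wavelength_angstrom(acceleration_voltage: float) -> float:
    """Relativistic electron wavelength for an acceleration voltage given in volts."""
    h = 6.62607015e-34     # Planck constant, J*s
    m0 = 9.1093837015e-31  # electron rest mass, kg
    e = 1.602176634e-19    # elementary charge, C
    c = 299792458.0        # speed of light, m/s

    energy = e * acceleration_voltage
    wavelength = h / np.sqrt(2 * m0 * energy * (1 + energy / (2 * m0 * c**2)))
    return wavelength * 1e10  # convert meters to Ångstrom


print(round(electron_wavelength_angstrom(300e3), 4))  # ~0.0197 Å at 300 kV
```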