pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/matching_utils.py CHANGED
@@ -9,7 +9,6 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
9
9
  import os
10
10
  import pickle
11
11
  from shutil import move
12
- from joblib import Parallel
13
12
  from tempfile import mkstemp
14
13
  from itertools import product
15
14
  from gzip import open as gzip_open
@@ -17,11 +16,29 @@ from typing import Tuple, Dict, Callable
17
16
  from concurrent.futures import ThreadPoolExecutor
18
17
 
19
18
  import numpy as np
20
- from scipy.spatial import ConvexHull
21
19
 
22
20
  from .backends import backend as be
23
21
  from .memory import estimate_memory_usage
24
- from .types import NDArray, BackendArray
22
+ from .types import NDArray, BackendArray, MatchingData
23
+
24
+
25
+ def copy_docstring(source_func, append: bool = True):
26
+ """Decorator to copy docstring from source function."""
27
+
28
+ def decorator(target_func):
29
+ base_doc = source_func.__doc__ or ""
30
+ if append and target_func.__doc__:
31
+ target_func.__doc__ = base_doc + "\n\n" + target_func.__doc__
32
+ else:
33
+ target_func.__doc__ = base_doc
34
+ return target_func
35
+
36
+ return decorator
37
+
38
+
39
+ def to_padded(buffer, data, unpadded_slice):
40
+ buffer = be.fill(buffer, 0)
41
+ return be.at(buffer, unpadded_slice, data)
25
42
 
26
43
 
27
44
  def identity(arr, *args, **kwargs):
@@ -54,7 +71,7 @@ def conditional_execute(
54
71
  return func if execute_operation else alt_func
55
72
 
56
73
 
57
- def normalize_template(
74
+ def standardize(
58
75
  template: BackendArray, mask: BackendArray, n_observations: float, axis=None
59
76
  ) -> BackendArray:
60
77
  """
@@ -95,12 +112,13 @@ def normalize_template(
95
112
  return be.multiply(template, mask, out=template)
96
113
 
97
114
 
98
- def _normalize_template_overflow_safe(
115
+ def _standardize_safe(
99
116
  template: BackendArray, mask: BackendArray, n_observations: float, axis=None
100
117
  ) -> BackendArray:
118
+ """Overflow-safe version of standardize using higher precision arithmetic."""
101
119
  _template = be.astype(template, be._overflow_safe_dtype)
102
120
  _mask = be.astype(mask, be._overflow_safe_dtype)
103
- normalize_template(
121
+ standardize(
104
122
  template=_template, mask=_mask, n_observations=n_observations, axis=axis
105
123
  )
106
124
  template[:] = be.astype(_template, template.dtype)
@@ -572,18 +590,20 @@ def split_shape(
572
590
  return splits
573
591
 
574
592
 
575
- def rigid_transform(
593
+ def _rigid_transform(
576
594
  coordinates: NDArray,
577
595
  rotation_matrix: NDArray,
578
596
  out: NDArray,
579
597
  translation: NDArray,
580
- use_geometric_center: bool = False,
581
598
  coordinates_mask: NDArray = None,
582
599
  out_mask: NDArray = None,
583
600
  center: NDArray = None,
601
+ **kwargs,
584
602
  ) -> None:
585
603
  """
586
- Apply a rigid transformation (rotation and translation) to given coordinates.
604
+ Apply a rigid transformation to given coordinates as
605
+
606
+ rotation_matrix.T @ coordinates + translation
587
607
 
588
608
  Parameters
589
609
  ----------
@@ -599,39 +619,25 @@ def rigid_transform(
599
619
  An array representing the mask for the coordinates (d,t).
600
620
  out_mask : NDArray, optional
601
621
  The output array to store the transformed coordinates mask (d,t).
602
- use_geometric_center : bool, optional
603
- Whether to use geometric or coordinate center.
622
+ center : NDArray, optional
623
+ Coordinate center, defaults to the average along each axis.
604
624
  """
605
- coordinate_dtype = coordinates.dtype
606
- center = coordinates.mean(axis=1) if center is None else center
607
- if not use_geometric_center:
608
- coordinates = coordinates - center[:, None]
609
-
610
- np.matmul(rotation_matrix, coordinates, out=out)
611
- if use_geometric_center:
612
- axis_max, axis_min = out.max(axis=1), out.min(axis=1)
613
- axis_difference = axis_max - axis_min
614
- translation = np.add(translation, center - axis_max + (axis_difference // 2))
615
- else:
616
- translation = np.add(translation, np.subtract(center, out.mean(axis=1)))
625
+ if center is None:
626
+ center = coordinates.mean(axis=1)
617
627
 
618
- out += translation[:, None]
619
- if coordinates_mask is not None and out_mask is not None:
620
- if not use_geometric_center:
621
- coordinates_mask = coordinates_mask - center[:, None]
622
- np.matmul(rotation_matrix, coordinates_mask, out=out_mask)
623
- out_mask += translation[:, None]
628
+ coordinates = coordinates - center[:, None]
629
+ out = np.matmul(rotation_matrix.T, coordinates, out=out)
630
+ translation = np.add(translation, center)
624
631
 
625
- if not use_geometric_center and coordinate_dtype != out.dtype:
626
- np.subtract(out.mean(axis=1), out.astype(int).mean(axis=1), out=translation)
627
- out += translation[:, None]
632
+ out = np.add(out, translation[:, None], out=out)
633
+ if coordinates_mask is not None and out_mask is not None:
634
+ np.matmul(rotation_matrix.T, coordinates_mask, out=out_mask)
635
+ out_mask = np.add(out_mask, translation[:, None], out=out_mask)
628
636
 
629
637
 
630
- def minimum_enclosing_box(
631
- coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
632
- ) -> Tuple[int]:
638
+ def minimum_enclosing_box(coordinates: NDArray, **kwargs) -> Tuple[int, ...]:
633
639
  """
634
- Computes the minimal enclosing box around coordinates with margin.
640
+ Computes the minimal enclosing box around coordinates.
635
641
 
636
642
  Parameters
637
643
  ----------
@@ -639,35 +645,30 @@ def minimum_enclosing_box(
639
645
  Coordinates of shape (d,n) to compute the enclosing box of.
640
646
  margin : NDArray, optional
641
647
  Box margin, zero by default.
648
+
649
+ .. deprecated:: 0.3.2
650
+
651
+ Boxed are returned without margin.
652
+
642
653
  use_geometric_center : bool, optional
643
654
  Whether box accommodates the geometric or coordinate center, False by default.
644
655
 
656
+ .. deprecated:: 0.3.2
657
+
658
+ Boxes always accomodate the coordinate center
659
+
645
660
  Returns
646
661
  -------
647
- tuple of ints
648
- Minimum enclosing box shape.
662
+ tuple of int
663
+ Minimum enclosing box.
649
664
  """
650
- from .extensions import max_euclidean_distance
665
+ coordinates = np.asarray(coordinates).T
666
+ coordinates = coordinates - coordinates.min(axis=0)
667
+ coordinates = coordinates - coordinates.mean(axis=0)
651
668
 
652
- point_cloud = np.asarray(coordinates)
653
- dim = point_cloud.shape[0]
654
- point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
655
-
656
- margin = np.zeros(dim) if margin is None else margin
657
- margin = np.asarray(margin).astype(int)
658
-
659
- norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
660
669
  # Adding one avoids clipping during scipy.ndimage.affine_transform
661
- shape = np.repeat(
662
- np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
663
- ).astype(int)
664
- if use_geometric_center:
665
- hull = ConvexHull(point_cloud.T)
666
- distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
667
- distance += np.linalg.norm(np.ones(dim))
668
- shape = np.repeat(np.rint(distance).astype(int), dim)
669
-
670
- return shape
670
+ box_size = int(np.ceil(2 * np.linalg.norm(coordinates, axis=1).max()) + 1)
671
+ return tuple(box_size for _ in range(coordinates.shape[1]))
671
672
 
672
673
 
673
674
  def scramble_phases(
@@ -684,15 +685,13 @@ def scramble_phases(
684
685
  Proportion of scrambled phases, 1.0 by default.
685
686
  seed : int, optional
686
687
  The seed for the random phase scrambling, 42 by default.
687
- normalize_power : bool, optional
688
- Return value has same sum of squares as ``arr``.
689
688
 
690
689
  Returns
691
690
  -------
692
691
  NDArray
693
692
  Phase scrambled version of ``arr``.
694
693
  """
695
- from tme.filters._utils import fftfreqn
694
+ from .filters._utils import fftfreqn
696
695
 
697
696
  np.random.seed(seed)
698
697
  noise_proportion = max(min(noise_proportion, 1), 0)
@@ -700,9 +699,11 @@ def scramble_phases(
700
699
  arr_fft = np.fft.fftn(arr)
701
700
  amp, ph = np.abs(arr_fft), np.angle(arr_fft)
702
701
 
703
- # Scrambling up to nyquist gives more uniform noise distribution
704
- mask = np.fft.ifftshift(
705
- fftfreqn(arr_fft.shape, sampling_rate=1, compute_euclidean_norm=True) <= 0.5
702
+ mask = (
703
+ fftfreqn(
704
+ arr_fft.shape, sampling_rate=1, compute_euclidean_norm=True, fftshift=False
705
+ )
706
+ <= 0.5
706
707
  )
707
708
 
708
709
  ph_noise = np.random.permutation(ph[mask])
@@ -811,38 +812,96 @@ def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
811
812
  return mask
812
813
 
813
814
 
814
- class TqdmParallel(Parallel):
815
- """
816
- A minimal Parallel implementation using tqdm for progress reporting.
817
-
818
- Parameters:
819
- -----------
820
- tqdm_args : dict, optional
821
- Dictionary of arguments passed to tqdm.tqdm
822
- *args, **kwargs:
823
- Arguments to pass to joblib.Parallel
824
- """
815
+ def setup_filter(
816
+ matching_data: MatchingData,
817
+ fast_shape: Tuple[int],
818
+ fast_ft_shape: Tuple[int],
819
+ pad_template_filter: bool = False,
820
+ apply_target_filter: bool = False,
821
+ ):
822
+ from .filters import Compose
823
+
824
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
825
+ template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
826
+ target_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
827
+ if isinstance(matching_data.template_filter, backend_arr):
828
+ template_filter = matching_data.template_filter
825
829
 
826
- def __init__(self, tqdm_args: Dict = {}, *args, **kwargs):
827
- from tqdm import tqdm
830
+ if isinstance(matching_data.target_filter, backend_arr):
831
+ target_filter = matching_data.target_filter
828
832
 
829
- super().__init__(*args, **kwargs)
830
- self.pbar = tqdm(**tqdm_args)
833
+ filter_template = isinstance(matching_data.template_filter, Compose)
834
+ filter_target = isinstance(matching_data.target_filter, Compose)
831
835
 
832
- def __call__(self, iterable, *args, **kwargs):
833
- self.n_tasks = len(iterable) if hasattr(iterable, "__len__") else None
834
- return super().__call__(iterable, *args, **kwargs)
836
+ # For now assume user-supplied template_filter is correctly padded
837
+ if filter_target is None and target_filter is None:
838
+ return template_filter
835
839
 
836
- def print_progress(self):
837
- if self.n_tasks is None:
838
- return super().print_progress()
840
+ batch_mask = matching_data._batch_mask
841
+ real_shape = matching_data._batch_shape(fast_shape, batch_mask, keepdims=False)
842
+ cmpl_shape = matching_data._batch_shape(fast_ft_shape, batch_mask, keepdims=True)
839
843
 
840
- if self.n_tasks != self.pbar.total:
841
- self.pbar.total = self.n_tasks
842
- self.pbar.refresh()
844
+ real_tmpl_shape, cmpl_tmpl_shape = real_shape, cmpl_shape
845
+ if not pad_template_filter:
846
+ shape = matching_data._output_template_shape
843
847
 
844
- self.pbar.n = self.n_completed_tasks
845
- self.pbar.refresh()
848
+ real_tmpl_shape = matching_data._batch_shape(shape, batch_mask, keepdims=False)
849
+ cmpl_tmpl_shape = matching_data._batch_shape(shape, batch_mask, keepdims=True)
850
+ cmpl_tmpl_shape = list(cmpl_tmpl_shape)
851
+ cmpl_tmpl_shape[-1] = cmpl_tmpl_shape[-1] // 2 + 1
846
852
 
847
- if self.n_completed_tasks >= self.n_tasks:
848
- self.pbar.close()
853
+ cmpl_shape = tuple(
854
+ -1 if y else x for x, y in zip(cmpl_shape, matching_data._target_batch)
855
+ )
856
+ cmpl_tmpl_shape = list(
857
+ -1 if y else x for x, y in zip(cmpl_tmpl_shape, matching_data._template_batch)
858
+ )
859
+
860
+ # We can have one flexible dimension and this makes projection matching easier
861
+ if not any(matching_data._template_batch):
862
+ cmpl_tmpl_shape[0] = -1
863
+
864
+ # Avoid invalidating the meaning of some filters on padded batch dimensions
865
+ target_shape = np.maximum(
866
+ np.multiply(fast_shape, tuple(1 - x for x in matching_data._target_batch)),
867
+ matching_data.target.shape,
868
+ )
869
+ target_shape = tuple(int(x) for x in target_shape)
870
+ target_temp = be.topleft_pad(matching_data.target, target_shape)
871
+ shape = matching_data._batch_shape(
872
+ target_temp.shape, matching_data._target_batch, keepdims=False
873
+ )
874
+ axes = matching_data._batch_axis(matching_data._target_batch)
875
+ target_temp_ft = be.rfftn(target_temp, s=shape, axes=axes)
876
+
877
+ # Setup composable filters
878
+ filter_kwargs = {
879
+ "return_real_fourier": True,
880
+ "shape_is_real_fourier": False,
881
+ "data_rfft": target_temp_ft,
882
+ "axes": matching_data._target_dim,
883
+ }
884
+ if filter_template:
885
+ template_filter = matching_data.template_filter(
886
+ shape=real_tmpl_shape, **filter_kwargs
887
+ )["data"]
888
+ template_filter = be.reshape(template_filter, cmpl_tmpl_shape)
889
+ template_filter = be.astype(
890
+ be.to_backend_array(template_filter), be._float_dtype
891
+ )
892
+ template_filter = be.at(template_filter, ((0,) * template_filter.ndim), 0)
893
+
894
+ if filter_target:
895
+ target_filter = matching_data.target_filter(
896
+ shape=real_shape, weight_type=None, **filter_kwargs
897
+ )["data"]
898
+ target_filter = be.reshape(target_filter, cmpl_shape)
899
+ target_filter = be.astype(be.to_backend_array(target_filter), be._float_dtype)
900
+ target_filter = be.at(target_filter, ((0,) * target_filter.ndim), 0)
901
+
902
+ if apply_target_filter and filter_target:
903
+ target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
904
+ target_temp = be.irfftn(target_temp_ft, s=shape, axes=axes)
905
+ matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
906
+
907
+ return template_filter, target_filter
tme/memory.py CHANGED
@@ -236,6 +236,7 @@ MATCHING_MEMORY_REGISTRY = {
236
236
  "CC": CCMemoryUsage,
237
237
  "LCC": LCCMemoryUsage,
238
238
  "CORR": CORRMemoryUsage,
239
+ "NCC": CORRMemoryUsage,
239
240
  "CAM": CAMMemoryUsage,
240
241
  "MCC": MCCMemoryUsage,
241
242
  "FLCSphericalMask": FLCSphericalMaskMemoryUsage,
tme/orientations.py CHANGED
@@ -13,6 +13,7 @@ from string import ascii_lowercase, ascii_uppercase
13
13
  import numpy as np
14
14
 
15
15
  from .parser import StarParser
16
+ from .__version__ import __version__
16
17
  from .matching_utils import compute_extraction_box
17
18
 
18
19
  # Exceeds available numpy dimensions for default installations
@@ -341,6 +342,8 @@ class Orientations:
341
342
  header.append("_pytmeScore")
342
343
  header = "\n".join(header)
343
344
  with open(filename, mode="w", encoding="utf-8") as ofile:
345
+ _ = ofile.write(f"# Created using pytme (version {__version__}).\n\n")
346
+
344
347
  if version is not None:
345
348
  _ = ofile.write(f"{version.strip()}\n\n")
346
349
 
@@ -494,21 +497,38 @@ class Orientations:
494
497
 
495
498
  @classmethod
496
499
  def _from_star(
497
- cls, filename: str, delimiter: str = "\t"
500
+ cls, filename: str, delimiter: str = None
498
501
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
499
502
  parser = StarParser(filename, delimiter=delimiter)
500
503
 
501
- ret = parser.get("data_particles", None)
502
- if ret is None:
503
- ret = parser.get("data_", None)
504
+ keyword_order = ("data_particles", "particles", "data")
505
+ for keyword in keyword_order:
506
+ ret = parser.get(keyword, None)
507
+ if ret is None:
508
+ ret = parser.get(f"{keyword}_", None)
509
+ if ret is not None:
510
+ break
504
511
 
505
512
  if ret is None:
506
- raise ValueError(f"No data_particles section found in {filename}.")
513
+ raise ValueError(
514
+ f"Could not find either {keyword_order} section found in {filename}."
515
+ )
507
516
 
508
- translation = np.vstack(
509
- (ret["_rlnCoordinateX"], ret["_rlnCoordinateY"], ret["_rlnCoordinateZ"])
517
+ keys_v4 = ("_rlnCoordinateX", "_rlnCoordinateY", "_rlnCoordinateZ")
518
+ keys_v5 = (
519
+ "_rlnCenteredCoordinateXAngst",
520
+ "_rlnCenteredCoordinateYAngst",
521
+ "_rlnCenteredCoordinateZAngst",
510
522
  )
511
- translation = translation.astype(np.float32).T
523
+ if all(key in ret for key in keys_v4):
524
+ keys = keys_v4
525
+ elif all(key in ret for key in keys_v5):
526
+ keys = keys_v5
527
+ else:
528
+ raise ValueError(
529
+ f"File format not recognized. Need either {keys_v4} or {keys_v5}."
530
+ )
531
+ translation = np.vstack(tuple(ret[x] for x in keys)).astype(np.float32).T
512
532
 
513
533
  default_angle = np.zeros(translation.shape[0], dtype=np.float32)
514
534
  for x in ("_rlnAngleRot", "_rlnAngleTilt", "_rlnAnglePsi"):
tme/preprocessor.py CHANGED
@@ -520,33 +520,6 @@ class Preprocessor:
520
520
 
521
521
  return template
522
522
 
523
- def mipmap_filter(self, template: NDArray, level: int) -> NDArray:
524
- """
525
- Perform mip map antialiasing filtering.
526
-
527
- Parameters
528
- ----------
529
- template : NDArray
530
- The input atomic structure map.
531
- level : int
532
- Pyramid layer. Resolution decreases cubically with level.
533
-
534
- Returns
535
- -------
536
- NDArray
537
- Simulated electron densities.
538
- """
539
- array = template.copy()
540
- interpolation_box = array.shape
541
-
542
- for k in range(template.ndim):
543
- array = ndimage.decimate(array, q=level, axis=k)
544
-
545
- template = ndimage.zoom(array, np.divide(template.shape, array.shape))
546
- template = self.interpolate_box(box=interpolation_box, arr=template)
547
-
548
- return template
549
-
550
523
  def interpolate_box(
551
524
  self, arr: NDArray, box: Tuple[int], kind: str = "nearest"
552
525
  ) -> NDArray:
@@ -588,218 +561,6 @@ class Preprocessor:
588
561
 
589
562
  return arr
590
563
 
591
- def bandpass_filter(
592
- self,
593
- template: NDArray,
594
- lowpass: float,
595
- highpass: float,
596
- sampling_rate: NDArray = None,
597
- gaussian_sigma: float = 0.0,
598
- ) -> NDArray:
599
- """
600
- Apply a band-pass filter on the provided template, using a
601
- Butterworth approximation.
602
-
603
- Parameters
604
- ----------
605
- template : NDArray
606
- The input numpy array on which the band-pass filter should be applied.
607
- lowpass : float
608
- The lower boundary of the frequency range to be preserved. Lower values will
609
- retain broader, more global features.
610
- highpass : float
611
- The upper boundary of the frequency range to be preserved. Higher values
612
- will emphasize finer details and potentially noise.
613
- sampling_rate : NDarray, optional
614
- The sampling rate along each dimension.
615
- gaussian_sigma : float, optional
616
- Sigma value for the gaussian smoothing to be applied to the filter.
617
-
618
- Returns
619
- -------
620
- NDArray
621
- Bandpass filtered numpy array.
622
- """
623
- bpf = self.bandpass_mask(
624
- shape=template.shape,
625
- lowpass=lowpass,
626
- highpass=highpass,
627
- sampling_rate=sampling_rate,
628
- gaussian_sigma=gaussian_sigma,
629
- omit_negative_frequencies=False,
630
- )
631
-
632
- fft_data = np.fft.fftn(template)
633
- np.multiply(fft_data, bpf, out=fft_data)
634
- ret = np.real(np.fft.ifftn(fft_data))
635
- return ret
636
-
637
- def bandpass_mask(
638
- self,
639
- shape: Tuple[int],
640
- lowpass: float,
641
- highpass: float,
642
- sampling_rate: NDArray = None,
643
- gaussian_sigma: float = 0.0,
644
- omit_negative_frequencies: bool = True,
645
- ) -> NDArray:
646
- """
647
- Compute an approximate Butterworth bundpass filter. The returned filter
648
- has it's DC component at the origin.
649
-
650
- Parameters
651
- ----------
652
- shape : tuple of ints
653
- Shape of the returned bandpass filter.
654
- lowpass : float
655
- The lower boundary of the frequency range to be preserved. Lower values will
656
- retain broader, more global features.
657
- maximum_frequency : float
658
- The upper boundary of the frequency range to be preserved. Higher values
659
- will emphasize finer details and potentially noise.
660
- sampling_rate : NDarray, optional
661
- The sampling rate along each dimension.
662
- gaussian_sigma : float, optional
663
- Sigma value for the gaussian smoothing to be applied to the filter.
664
- omit_negative_frequencies : bool, optional
665
- Whether the wedge mask should omit negative frequencies, i.e. be
666
- applicable to non hermitian-symmetric fourier transforms.
667
-
668
- Returns
669
- -------
670
- NDArray
671
- Bandpass filtered.
672
- """
673
- from .filters import BandPassReconstructed
674
-
675
- return BandPassReconstructed(
676
- sampling_rate=sampling_rate,
677
- lowpass=lowpass,
678
- highpass=highpass,
679
- return_real_fourier=omit_negative_frequencies,
680
- use_gaussian=gaussian_sigma == 0.0,
681
- )(shape=shape)["data"]
682
-
683
- def step_wedge_mask(
684
- self,
685
- shape: Tuple[int],
686
- tilt_angles: Tuple[float] = None,
687
- opening_axis: int = 0,
688
- tilt_axis: int = 2,
689
- weights: float = None,
690
- infinite_plane: bool = False,
691
- omit_negative_frequencies: bool = True,
692
- ) -> NDArray:
693
- """
694
- Create a wedge mask with the same shape as template by rotating a
695
- plane according to tilt angles. The DC component of the filter is at the origin.
696
-
697
- Parameters
698
- ----------
699
- tilt_angles : tuple of float
700
- Sequence of tilt angles.
701
- shape : Tuple of ints
702
- Shape of the output wedge array.
703
- tilt_axis : int, optional
704
- Axis that the plane is tilted over.
705
- - 0 for Z-axis
706
- - 1 for Y-axis
707
- - 2 for X-axis
708
- opening_axis : int, optional
709
- Axis running through the void defined by the wedge.
710
- - 0 for Z-axis
711
- - 1 for Y-axis
712
- - 2 for X-axis
713
- sigma : float, optional
714
- Standard deviation for Gaussian kernel used for smoothing the wedge.
715
- weights : float, tuple of float
716
- Weight of each element in the wedge. Defaults to one.
717
- omit_negative_frequencies : bool, optional
718
- Whether the wedge mask should omit negative frequencies, i.e. be
719
- applicable to symmetric Fourier transforms (see :obj:`numpy.fft.fftn`)
720
-
721
- Returns
722
- -------
723
- NDArray
724
- A numpy array containing the wedge mask.
725
-
726
- See Also
727
- --------
728
- :py:meth:`Preprocessor.continuous_wedge_mask`
729
- """
730
- from .filters import WedgeReconstructed
731
-
732
- return WedgeReconstructed(
733
- angles=tilt_angles,
734
- tilt_axis=tilt_axis,
735
- opening_axis=opening_axis,
736
- frequency_cutoff=None if infinite_plane else 0.5,
737
- create_continuous_wedge=False,
738
- weights=weights,
739
- weight_wedge=weights is not None,
740
- )(shape=shape, return_real_fourier=omit_negative_frequencies,)["data"]
741
-
742
- def continuous_wedge_mask(
743
- self,
744
- start_tilt: float,
745
- stop_tilt: float,
746
- shape: Tuple[int],
747
- opening_axis: int = 0,
748
- tilt_axis: int = 2,
749
- infinite_plane: bool = True,
750
- omit_negative_frequencies: bool = True,
751
- ) -> NDArray:
752
- """
753
- Generate a wedge in a given shape based on specified tilt angles and axis.
754
- The DC component of the filter is at the origin.
755
-
756
- Parameters
757
- ----------
758
- start_tilt : float
759
- Starting tilt angle in degrees, e.g. a stage tilt of 70 degrees
760
- would yield a start_tilt value of 70.
761
- stop_tilt : float
762
- Ending tilt angle in degrees, , e.g. a stage tilt of -70 degrees
763
- would yield a stop_tilt value of 70.
764
- tilt_axis : int
765
- Axis that the plane is tilted over.
766
- - 0 for Z-axis
767
- - 1 for Y-axis
768
- - 2 for X-axis
769
- opening_axis : int
770
- Axis running through the void defined by the wedge.
771
- - 0 for Z-axis
772
- - 1 for Y-axis
773
- - 2 for X-axis
774
- shape : Tuple of ints
775
- Shape of the output wedge array.
776
- omit_negative_frequencies : bool, optional
777
- Whether the wedge mask should omit negative frequencies, i.e. be
778
- applicable to symmetric Fourier transforms (see :obj:`numpy.fft.fftn`)
779
- infinite_plane : bool, optional
780
- Whether the plane should be considered to be larger than the shape. In this
781
- case the output wedge mask fill have no spheric component.
782
-
783
- Returns
784
- -------
785
- NDArray
786
- Array of the specified shape with the wedge created based on
787
- the tilt angles.
788
-
789
- See Also
790
- --------
791
- :py:meth:`Preprocessor.step_wedge_mask`
792
- """
793
- from .filters import WedgeReconstructed
794
-
795
- return WedgeReconstructed(
796
- angles=(start_tilt, stop_tilt),
797
- tilt_axis=tilt_axis,
798
- opening_axis=opening_axis,
799
- frequency_cutoff=None if infinite_plane else 0.5,
800
- create_continuous_wedge=True,
801
- )(shape=shape, return_real_fourier=omit_negative_frequencies)["data"]
802
-
803
564
 
804
565
  def window_kaiserb(width: int, beta: float = 3.2, order: int = 0) -> NDArray:
805
566
  """