pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/METADATA +10 -9
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -30
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +384 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -91
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +41 -3
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,9 @@
1
- """ Defines filters on tomographic tilt series.
1
+ """
2
+ Defines filters on tomographic tilt series.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from typing import Tuple
@@ -15,27 +16,26 @@ from ..backends import backend as be
15
16
 
16
17
  from .compose import ComposableFilter
17
18
  from ..rotations import euler_to_rotationmatrix
18
- from ._utils import (
19
- crop_real_fourier,
20
- shift_fourier,
21
- create_reconstruction_filter,
22
- )
19
+ from ._utils import crop_real_fourier, shift_fourier, create_reconstruction_filter
23
20
 
24
21
  __all__ = ["ReconstructFromTilt"]
25
22
 
26
23
 
27
24
  @dataclass
28
25
  class ReconstructFromTilt(ComposableFilter):
29
- """Reconstruct a volume from a tilt series."""
26
+ """
27
+ Reconstruct a d+1 array from a d-dimensional input projection using weighted
28
+ backprojection (WBP).
29
+ """
30
30
 
31
31
  #: Shape of the reconstruction.
32
32
  shape: Tuple[int] = None
33
33
  #: Angle of each individual tilt.
34
34
  angles: Tuple[float] = None
35
- #: The axis around which the volume is opened.
36
- opening_axis: int = 0
37
- #: Axis the plane is tilted over.
38
- tilt_axis: int = 2
35
+ #: Projection axis, defaults to 2 (z).
36
+ opening_axis: int = 2
37
+ #: Tilt axis, defaults to 0 (x).
38
+ tilt_axis: int = 0
39
39
  #: Whether to return a share compliant with rfftn.
40
40
  return_real_fourier: bool = True
41
41
  #: Interpolation order used for rotation
@@ -43,19 +43,57 @@ class ReconstructFromTilt(ComposableFilter):
43
43
  #: Filter window applied during reconstruction.
44
44
  reconstruction_filter: str = None
45
45
 
46
- def __call__(self, **kwargs):
46
+ def __call__(self, return_real_fourier: bool = False, **kwargs):
47
+ """
48
+ Reconstruct a d+1 array from a d-dimensional input using WBP.
49
+
50
+ Parameters
51
+ ----------
52
+ shape : tuple of int
53
+ The shape of the reconstruction volume.
54
+ data : BackendArray
55
+ D-dimensional image stack with shape (n, ...)
56
+ angles : tuple of float
57
+ Angle of each individual tilt.
58
+ return_real_fourier : bool, optional
59
+ Return a shape compliant
60
+ return_real_fourier : tuple of int
61
+ Return a shape compliant with rfft, i.e., omit the negative frequencies
62
+ terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
63
+ to False.
64
+ reconstruction_filter : bool, optional
65
+ Filter window applied during reconstruction.
66
+ See :py:meth:`create_reconstruction_filter` for available options.
67
+ tilt_axis : int
68
+ Axis the plane is tilted over, defaults to 0 (x).
69
+ opening_axis : int
70
+ The projection axis, defaults to 2 (z).
71
+
72
+ Returns
73
+ -------
74
+ dict
75
+ data: BackendArray
76
+ The filter mask.
77
+ shape: tuple of ints
78
+ The requested filter shape
79
+ return_real_fourier: bool
80
+ Whether data is compliant with rfftn.
81
+ is_multiplicative_filter: bool
82
+ Whether the filter is multiplicative in Fourier space.
83
+ """
84
+
47
85
  func_args = vars(self).copy()
48
86
  func_args.update(kwargs)
49
87
 
50
88
  ret = self.reconstruct(**func_args)
51
89
 
90
+ if return_real_fourier:
91
+ ret = crop_real_fourier(ret)
92
+
52
93
  return {
53
94
  "data": ret,
54
- "shape": ret.shape,
55
- "shape_is_real_fourier": func_args["return_real_fourier"],
56
- "angles": func_args["angles"],
57
- "tilt_axis": func_args["tilt_axis"],
58
- "opening_axis": func_args["opening_axis"],
95
+ "shape": func_args["shape"],
96
+ "shape_is_real_fourier": return_real_fourier,
59
97
  "is_multiplicative_filter": False,
60
98
  }
61
99
 
@@ -67,7 +105,6 @@ class ReconstructFromTilt(ComposableFilter):
67
105
  opening_axis: int,
68
106
  tilt_axis: int,
69
107
  interpolation_order: int = 1,
70
- return_real_fourier: bool = True,
71
108
  reconstruction_filter: str = None,
72
109
  **kwargs,
73
110
  ):
@@ -88,8 +125,6 @@ class ReconstructFromTilt(ComposableFilter):
88
125
  Axis the plane is tilted over.
89
126
  interpolation_order : int, optional
90
127
  Interpolation order used for rotation, defaults to 1.
91
- return_real_fourier : bool, optional
92
- Whether to return a shape compliant with rfftn, defaults to True.
93
128
  reconstruction_filter : bool, optional
94
129
  Filter window applied during reconstruction.
95
130
  See :py:meth:`create_reconstruction_filter` for available options.
@@ -152,9 +187,4 @@ class ReconstructFromTilt(ComposableFilter):
152
187
  )
153
188
  volume = be.add(volume, volume_temp_rotated, out=volume)
154
189
 
155
- volume = shift_fourier(data=volume, shape_is_real_fourier=False)
156
-
157
- if return_real_fourier:
158
- volume = crop_real_fourier(volume)
159
-
160
- return volume
190
+ return shift_fourier(data=volume, shape_is_real_fourier=False)
tme/filters/wedge.py CHANGED
@@ -1,11 +1,13 @@
1
- """ Implements class Wedge and WedgeReconstructed to create Fourier
2
- filter representations.
1
+ """
2
+ Implements class Wedge and WedgeReconstructed to create Fourier
3
+ filter representations.
3
4
 
4
- Copyright (c) 2024 European Molecular Biology Laboratory
5
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
6
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
8
  """
8
9
 
10
+ import warnings
9
11
  from typing import Tuple, Dict
10
12
 
11
13
  import numpy as np
@@ -14,6 +16,7 @@ from ..types import NDArray
14
16
  from ..backends import backend as be
15
17
  from .compose import ComposableFilter
16
18
  from ..matching_utils import centered
19
+ from ..parser import XMLParser, StarParser
17
20
  from ..rotations import euler_to_rotationmatrix
18
21
  from ._utils import (
19
22
  centered_grid,
@@ -37,9 +40,9 @@ class Wedge(ComposableFilter):
37
40
  shape : tuple of int
38
41
  The shape of the reconstruction volume.
39
42
  tilt_axis : int
40
- Axis the plane is tilted over.
43
+ Axis the plane is tilted over, defaults to 0 (x).
41
44
  opening_axis : int
42
- The axis around which the volume is opened.
45
+ The projection axis, defaults to 2 (z).
43
46
  angles : tuple of float
44
47
  The tilt angles.
45
48
  weights : tuple of float
@@ -51,17 +54,24 @@ class Wedge(ComposableFilter):
51
54
 
52
55
  Returns
53
56
  -------
54
- Dict
55
- A dictionary containing weighted wedges and related information.
57
+ dict
58
+ data: BackendArray
59
+ The filter mask.
60
+ shape: tuple of ints
61
+ The requested filter shape
62
+ return_real_fourier: bool
63
+ Whether data is compliant with rfftn.
64
+ is_multiplicative_filter: bool
65
+ Whether the filter is multiplicative in Fourier space.
56
66
  """
57
67
 
58
68
  def __init__(
59
69
  self,
60
70
  shape: Tuple[int],
61
- tilt_axis: int,
62
- opening_axis: int,
63
71
  angles: Tuple[float],
64
72
  weights: Tuple[float],
73
+ tilt_axis: int = 0,
74
+ opening_axis: int = 2,
65
75
  weight_type: str = None,
66
76
  frequency_cutoff: float = 0.5,
67
77
  ):
@@ -75,8 +85,16 @@ class Wedge(ComposableFilter):
75
85
  @classmethod
76
86
  def from_file(cls, filename: str) -> "Wedge":
77
87
  """
78
- Generate a :py:class:`Wedge` instance by reading tilt angles and weights
79
- from a tab-separated text file.
88
+ Generate a :py:class:`Wedge` instance by reading tilt angles and weights.
89
+ Supported extensions are:
90
+
91
+ +-------+---------------------------------------------------------+
92
+ | .star | Tomostar STAR file |
93
+ +-------+---------------------------------------------------------+
94
+ | .xml | WARP/M XML file |
95
+ +-------+---------------------------------------------------------+
96
+ | .* | Tab-separated file with optional column names |
97
+ +-------+---------------------------------------------------------+
80
98
 
81
99
  Parameters
82
100
  ----------
@@ -88,8 +106,13 @@ class Wedge(ComposableFilter):
88
106
  :py:class:`Wedge`
89
107
  Class instance instance initialized with angles and weights from the file.
90
108
  """
91
- data = cls._from_text(filename)
109
+ func = _from_text
110
+ if filename.lower().endswith("xml"):
111
+ func = _from_xml
112
+ elif filename.lower().endswith("star"):
113
+ func = _from_star
92
114
 
115
+ data = func(filename)
93
116
  angles, weights = data.get("angles", None), data.get("weights", None)
94
117
  if angles is None:
95
118
  raise ValueError(f"Could not find colum angles in {filename}")
@@ -108,32 +131,6 @@ class Wedge(ComposableFilter):
108
131
  weights=np.array(weights, dtype=np.float32),
109
132
  )
110
133
 
111
- @staticmethod
112
- def _from_text(filename: str, delimiter="\t") -> Dict:
113
- """
114
- Read column data from a text file.
115
-
116
- Parameters
117
- ----------
118
- filename : str
119
- The path to the text file.
120
- delimiter : str, optional
121
- The delimiter used in the file, defaults to '\t'.
122
-
123
- Returns
124
- -------
125
- Dict
126
- A dictionary with one key for each column.
127
- """
128
- with open(filename, mode="r", encoding="utf-8") as infile:
129
- data = [x.strip() for x in infile.read().split("\n")]
130
- data = [x.split("\t") for x in data if len(x)]
131
-
132
- headers = data.pop(0)
133
- ret = {header: list(column) for header, column in zip(headers, zip(*data))}
134
-
135
- return ret
136
-
137
134
  def __call__(self, **kwargs: Dict) -> NDArray:
138
135
  func_args = vars(self).copy()
139
136
  func_args.update(kwargs)
@@ -172,10 +169,7 @@ class Wedge(ComposableFilter):
172
169
 
173
170
  return {
174
171
  "data": ret,
175
- "angles": func_args["angles"],
176
- "tilt_axis": func_args["tilt_axis"],
177
- "opening_axis": func_args["opening_axis"],
178
- "sampling_rate": func_args.get("sampling_rate", 1),
172
+ "shape": func_args["shape"],
179
173
  "is_multiplicative_filter": True,
180
174
  }
181
175
 
@@ -297,10 +291,10 @@ class WedgeReconstructed:
297
291
  ----------
298
292
  angles :tuple of float, optional
299
293
  The tilt angles, defaults to None.
300
- opening_axis : int, optional
301
- The axis around which the wedge is opened.
302
- tilt_axis : int, optional
303
- The axis along which the tilt is applied.
294
+ tilt_axis : int
295
+ Axis the plane is tilted over, defaults to 0 (x).
296
+ opening_axis : int
297
+ The projection axis, defaults to 2 (z).
304
298
  weights : tuple of float, optional
305
299
  Weights to assign to individual wedge components.
306
300
  weight_wedge : bool, optional
@@ -317,8 +311,8 @@ class WedgeReconstructed:
317
311
 
318
312
  def __init__(
319
313
  self,
320
- opening_axis: int,
321
- tilt_axis: int,
314
+ opening_axis: int = 2,
315
+ tilt_axis: int = 0,
322
316
  angles: Tuple[float] = None,
323
317
  weights: Tuple[float] = None,
324
318
  weight_wedge: bool = False,
@@ -336,7 +330,9 @@ class WedgeReconstructed:
336
330
  self.create_continuous_wedge = create_continuous_wedge
337
331
  self.frequency_cutoff = frequency_cutoff
338
332
 
339
- def __call__(self, shape: Tuple[int], **kwargs: Dict) -> Dict:
333
+ def __call__(
334
+ self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs: Dict
335
+ ) -> Dict:
340
336
  """
341
337
  Generate the reconstructed wedge.
342
338
 
@@ -344,20 +340,28 @@ class WedgeReconstructed:
344
340
  ----------
345
341
  shape : tuple of int
346
342
  The shape of the reconstruction volume.
343
+ return_real_fourier : tuple of int
344
+ Return a shape compliant with rfft, i.e., omit the negative frequencies
345
+ terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
346
+ to False.
347
347
  **kwargs : Dict
348
348
  Additional keyword arguments.
349
349
 
350
350
  Returns
351
351
  -------
352
- Dict
353
- A dictionary containing the reconstructed wedge and related information.
352
+ dict
353
+ data: BackendArray
354
+ The filter mask.
355
+ shape: tuple of ints
356
+ The requested filter shape
357
+ return_real_fourier: bool
358
+ Whether data is compliant with rfftn.
359
+ is_multiplicative_filter: bool
360
+ Whether the filter is multiplicative in Fourier space.
354
361
  """
355
362
  func_args = vars(self).copy()
356
363
  func_args.update(kwargs)
357
364
 
358
- if kwargs.get("is_fourier_shape", False):
359
- print("Cannot create continuous wedge mask based on real fourier shape.")
360
-
361
365
  func = self.step_wedge
362
366
  if func_args.get("create_continuous_wedge", False):
363
367
  func = self.continuous_wedge
@@ -386,17 +390,15 @@ class WedgeReconstructed:
386
390
  ret = be.astype(be.to_backend_array(ret), be._float_dtype)
387
391
 
388
392
  ret = shift_fourier(data=ret, shape_is_real_fourier=False)
389
- if func_args.get("return_real_fourier", False):
393
+
394
+ if return_real_fourier:
390
395
  ret = crop_real_fourier(ret)
391
396
 
392
397
  return {
393
398
  "data": ret,
394
- "shape_is_real_fourier": func_args["return_real_fourier"],
395
- "shape": ret.shape,
396
- "tilt_axis": func_args["tilt_axis"],
397
- "opening_axis": func_args["opening_axis"],
399
+ "shape": shape,
400
+ "return_real_fourier": return_real_fourier,
398
401
  "is_multiplicative_filter": True,
399
- "angles": func_args["angles"],
400
402
  }
401
403
 
402
404
  @staticmethod
@@ -426,6 +428,7 @@ class WedgeReconstructed:
426
428
  NDArray
427
429
  Wedge mask.
428
430
  """
431
+ angles = np.abs(np.asarray(angles))
429
432
  aspect_ratio = shape[opening_axis] / shape[tilt_axis]
430
433
  angles = np.degrees(np.arctan(np.tan(np.radians(angles)) * aspect_ratio))
431
434
 
@@ -540,3 +543,78 @@ class WedgeReconstructed:
540
543
  wedge_volume = np.tile(wedge_volume, tile_dimensions)
541
544
 
542
545
  return wedge_volume
546
+
547
+
548
+ def _from_xml(filename: str, **kwargs) -> Dict:
549
+ """
550
+ Read tilt data from a WARP/M XML file.
551
+
552
+ Parameters
553
+ ----------
554
+ filename : str
555
+ The path to the text file.
556
+
557
+ Returns
558
+ -------
559
+ Dict
560
+ A dictionary with one key for each column.
561
+ """
562
+ data = XMLParser(filename)
563
+ return {"angles": data["Angles"], "weights": data["Dose"]}
564
+
565
+
566
+ def _from_star(filename: str, **kwargs) -> Dict:
567
+ """
568
+ Read tilt data from a tomostar STAR file.
569
+
570
+ Parameters
571
+ ----------
572
+ filename : str
573
+ The path to the text file.
574
+
575
+ Returns
576
+ -------
577
+ Dict
578
+ A dictionary with one key for each column.
579
+ """
580
+ data = StarParser(filename, delimiter=None)["data_"]
581
+ return {"angles": data["_wrpAxisAngle"], "weights": data["_wrpDose"]}
582
+
583
+
584
+ def _from_text(filename: str, **kwargs) -> Dict:
585
+ """
586
+ Read column data from a text file.
587
+
588
+ Parameters
589
+ ----------
590
+ filename : str
591
+ The path to the text file.
592
+ delimiter : str, optional
593
+ The delimiter used in the file, defaults to '\t'.
594
+
595
+ Returns
596
+ -------
597
+ Dict
598
+ A dictionary with one key for each column.
599
+ """
600
+ with open(filename, mode="r", encoding="utf-8") as infile:
601
+ data = [x.strip() for x in infile.read().split("\n")]
602
+ data = [x.split("\t") for x in data if len(x)]
603
+
604
+ if "angles" in data[0]:
605
+ headers = data.pop(0)
606
+ else:
607
+ warnings.warn(
608
+ f"Did not find a column named 'angles' in {filename}. Assuming "
609
+ "first column specifies angles."
610
+ )
611
+ if len(data[0]) != 1:
612
+ raise ValueError(
613
+ "Found more than one column without column names. Please add "
614
+ "column names to your file. If you only want to specify tilt "
615
+ "angles without column names, use a single column file."
616
+ )
617
+ headers = ("angles",)
618
+ ret = {header: list(column) for header, column in zip(headers, zip(*data))}
619
+
620
+ return ret
tme/filters/whitening.py CHANGED
@@ -1,8 +1,9 @@
1
- """ Implements class BandPassFilter to create Fourier filter representations.
1
+ """
2
+ Implements class BandPassFilter to create Fourier filter representations.
2
3
 
3
- Copyright (c) 2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  from typing import Tuple, Dict
@@ -14,7 +15,7 @@ from scipy.ndimage import map_coordinates
14
15
  from ..types import BackendArray
15
16
  from ..backends import backend as be
16
17
  from .compose import ComposableFilter
17
- from ._utils import fftfreqn, compute_fourier_shape
18
+ from ._utils import fftfreqn, compute_fourier_shape, shift_fourier
18
19
 
19
20
  __all__ = ["LinearWhiteningFilter"]
20
21
 
@@ -103,13 +104,6 @@ class LinearWhiteningFilter(ComposableFilter):
103
104
  shape_is_real_fourier: bool = True,
104
105
  order: int = 1,
105
106
  ) -> BackendArray:
106
- """
107
- References
108
- ----------
109
- .. [1] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
110
- R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
111
- 13375 (2023)
112
- """
113
107
  grid = fftfreqn(
114
108
  shape=shape,
115
109
  sampling_rate=0.5,
@@ -123,11 +117,13 @@ class LinearWhiteningFilter(ComposableFilter):
123
117
 
124
118
  def __call__(
125
119
  self,
120
+ shape: Tuple[int],
126
121
  data: BackendArray = None,
127
122
  data_rfft: BackendArray = None,
128
123
  n_bins: int = None,
129
124
  batch_dimension: int = None,
130
125
  order: int = 1,
126
+ return_real_fourier: bool = True,
131
127
  **kwargs: Dict,
132
128
  ) -> Dict:
133
129
  """
@@ -135,6 +131,8 @@ class LinearWhiteningFilter(ComposableFilter):
135
131
 
136
132
  Parameters
137
133
  ----------
134
+ shape : tuple of ints
135
+ Shape of the returned whitening filter.
138
136
  data : BackendArray, optional
139
137
  The input data, defaults to None.
140
138
  data_rfft : BackendArray, optional
@@ -143,49 +141,51 @@ class LinearWhiteningFilter(ComposableFilter):
143
141
  The number of bins for computing the spectrum, defaults to None.
144
142
  batch_dimension : int, optional
145
143
  Batch dimension to average over.
146
- order : int, optional
147
- Interpolation order to use.
144
+ return_real_fourier : tuple of int
145
+ Return a shape compliant with rfft, i.e., omit the negative frequencies
146
+ terms resulting in a return shape (*shape[:-1], shape[-1]//2+1)
148
147
  **kwargs : Dict
149
148
  Additional keyword arguments.
150
149
 
151
150
  Returns
152
151
  -------
153
- Dict
154
- Filter data and associated parameters.
152
+ dict
153
+ data: BackendArray
154
+ The filter mask.
155
+ shape: tuple of ints
156
+ The requested filter shape
157
+ return_real_fourier: bool
158
+ Whether data is compliant with rfftn.
159
+ is_multiplicative_filter: bool
160
+ Whether the filter is multiplicative in Fourier space.
155
161
  """
156
162
  if data_rfft is None:
157
- data_rfft = np.fft.rfftn(be.to_numpy_array(data))
163
+ data_rfft = be.rfftn(data)
158
164
 
159
165
  data_rfft = be.to_numpy_array(data_rfft)
160
-
161
166
  bins, radial_averages = self._compute_spectrum(
162
167
  data_rfft, n_bins, batch_dimension
163
168
  )
169
+ shape = tuple(int(x) for i, x in enumerate(shape) if i != batch_dimension)
164
170
 
165
- if order is None:
166
- cutoff = bins < radial_averages.size
167
- filter_mask = np.zeros(bins.shape, radial_averages.dtype)
168
- filter_mask[cutoff] = radial_averages[bins[cutoff]]
169
- else:
170
- shape = bins.shape
171
- if kwargs.get("shape", False):
172
- shape = compute_fourier_shape(
173
- shape=kwargs.get("shape"),
174
- shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
175
- )
176
-
177
- filter_mask = self._interpolate_spectrum(
178
- spectrum=radial_averages,
171
+ shape_filter = shape
172
+ if return_real_fourier:
173
+ shape_filter = compute_fourier_shape(
179
174
  shape=shape,
180
- shape_is_real_fourier=True,
175
+ shape_is_real_fourier=False,
181
176
  )
182
177
 
183
- filter_mask = np.fft.ifftshift(
184
- filter_mask,
185
- axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension),
178
+ ret = self._interpolate_spectrum(
179
+ spectrum=radial_averages,
180
+ shape=shape_filter,
181
+ shape_is_real_fourier=return_real_fourier,
186
182
  )
187
183
 
184
+ ret = shift_fourier(data=ret, shape_is_real_fourier=return_real_fourier)
185
+
188
186
  return {
189
- "data": be.to_backend_array(filter_mask),
187
+ "data": be.to_backend_array(ret),
188
+ "shape": shape,
189
+ "return_real_fourier": return_real_fourier,
190
190
  "is_multiplicative_filter": True,
191
191
  }
tme/matching_data.py CHANGED
@@ -1,8 +1,9 @@
1
- """ Class representation of template matching data.
1
+ """
2
+ Class representation of template matching data.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  import warnings
@@ -1,8 +1,9 @@
1
- """ Implements cross-correlation based template matching using different metrics.
1
+ """
2
+ Implements cross-correlation based template matching using different metrics.
2
3
 
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
5
 
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
7
  """
7
8
 
8
9
  import sys
@@ -19,6 +20,7 @@ from .filters import Compose
19
20
  from .backends import backend as be
20
21
  from .matching_utils import split_shape
21
22
  from .types import CallbackClass, MatchingData
23
+ from .analyzer.proxy import SharedAnalyzerProxy
22
24
  from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
23
25
  from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
24
26
 
@@ -242,6 +244,7 @@ def scan(
242
244
  "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
243
245
  "n_rotations": matching_data.rotations.shape[0],
244
246
  }
247
+ callback_class_args["inversion_mapping"] = n_jobs == 1
245
248
  default_callback_args.update(callback_class_args)
246
249
 
247
250
  setup = matching_setup(
@@ -257,13 +260,17 @@ def scan(
257
260
  matching_data._free_data()
258
261
  be.free_cache()
259
262
 
260
- # Some analyzers cannot be shared across processes
261
- if not getattr(callback_class, "is_shareable", False):
262
- jobs_per_callback_class = 1
263
-
264
263
  n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
265
264
  callback_classes = [
266
- callback_class(**default_callback_args) if callback_class else None
265
+ (
266
+ SharedAnalyzerProxy(
267
+ callback_class,
268
+ default_callback_args,
269
+ shm_handler=shm_handler if n_jobs > 1 else None,
270
+ )
271
+ if callback_class
272
+ else None
273
+ )
267
274
  for _ in range(n_callback_classes)
268
275
  ]
269
276
  ret = Parallel(n_jobs=n_jobs)(
@@ -276,14 +283,9 @@ def scan(
276
283
  )
277
284
  for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
278
285
  )
279
-
280
- # TODO: Make sure peak callers are thread safe to begin with
281
- if not getattr(callback_class, "is_shareable", False):
282
- callback_classes = ret
283
-
284
286
  callbacks = [
285
- tuple(callback._postprocess(**default_callback_args))
286
- for callback in callback_classes
287
+ callback.result(**default_callback_args)
288
+ for callback in ret[:n_callback_classes]
287
289
  if callback
288
290
  ]
289
291
  be.free_cache()
@@ -454,7 +456,6 @@ def scan_subsets(
454
456
  for index, (target_split, template_split) in enumerate(splits)
455
457
  ]
456
458
  )
457
-
458
459
  matching_data._free_data()
459
460
  if callback_class is not None:
460
461
  return callback_class.merge(results, **callback_class_args)