pytme 0.2.9.post1-cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1-cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
  8. pytme-0.3b0.post1.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +341 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +318 -189
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +12 -12
  19. scripts/pytme_runner.py +769 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/filters/__init__.py +3 -3
  49. tme/filters/_utils.py +36 -10
  50. tme/filters/bandpass.py +229 -188
  51. tme/filters/compose.py +5 -4
  52. tme/filters/ctf.py +516 -254
  53. tme/filters/reconstruction.py +91 -32
  54. tme/filters/wedge.py +196 -135
  55. tme/filters/whitening.py +37 -42
  56. tme/matching_data.py +28 -39
  57. tme/matching_exhaustive.py +31 -27
  58. tme/matching_optimization.py +5 -4
  59. tme/matching_scores.py +25 -15
  60. tme/matching_utils.py +54 -9
  61. tme/memory.py +4 -3
  62. tme/orientations.py +22 -9
  63. tme/parser.py +114 -33
  64. tme/preprocessor.py +6 -5
  65. tme/rotations.py +10 -7
  66. tme/structure.py +4 -3
  67. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  68. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  69. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  70. scripts/estimate_ram_usage.py +0 -97
  71. tests/data/Maps/.DS_Store +0 -0
  72. tests/data/Structures/.DS_Store +0 -0
  73. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/filters/whitening.py CHANGED
@@ -1,8 +1,9 @@
-""" Implements class BandPassFilter to create Fourier filter representations.
+"""
+Implements class BandPassFilter to create Fourier filter representations.
 
-    Copyright (c) 2024 European Molecular Biology Laboratory
+Copyright (c) 2024 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 from typing import Tuple, Dict
@@ -14,7 +15,7 @@ from scipy.ndimage import map_coordinates
 from ..types import BackendArray
 from ..backends import backend as be
 from .compose import ComposableFilter
-from ._utils import fftfreqn, compute_fourier_shape
+from ._utils import fftfreqn, compute_fourier_shape, shift_fourier
 
 __all__ = ["LinearWhiteningFilter"]
 
@@ -23,11 +24,6 @@ class LinearWhiteningFilter(ComposableFilter):
     """
     Compute Fourier power spectrums and perform whitening.
 
-    Parameters
-    ----------
-    **kwargs : Dict, optional
-        Additional keyword arguments.
-
     References
     ----------
     .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
@@ -38,7 +34,7 @@ class LinearWhiteningFilter(ComposableFilter):
            13375 (2023)
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, *args, **kwargs):
         pass
 
     @staticmethod
@@ -103,13 +99,6 @@ class LinearWhiteningFilter(ComposableFilter):
         shape_is_real_fourier: bool = True,
         order: int = 1,
     ) -> BackendArray:
-        """
-        References
-        ----------
-        .. [1] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
-               R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
-               13375 (2023)
-        """
         grid = fftfreqn(
             shape=shape,
             sampling_rate=0.5,
@@ -123,11 +112,13 @@ class LinearWhiteningFilter(ComposableFilter):
 
     def __call__(
         self,
+        shape: Tuple[int],
         data: BackendArray = None,
         data_rfft: BackendArray = None,
         n_bins: int = None,
         batch_dimension: int = None,
         order: int = 1,
+        return_real_fourier: bool = True,
         **kwargs: Dict,
     ) -> Dict:
         """
@@ -135,6 +126,8 @@ class LinearWhiteningFilter(ComposableFilter):
 
         Parameters
         ----------
+        shape : tuple of ints
+            Shape of the returned whitening filter.
         data : BackendArray, optional
             The input data, defaults to None.
         data_rfft : BackendArray, optional
@@ -143,49 +136,51 @@ class LinearWhiteningFilter(ComposableFilter):
             The number of bins for computing the spectrum, defaults to None.
         batch_dimension : int, optional
             Batch dimension to average over.
-        order : int, optional
-            Interpolation order to use.
+        return_real_fourier : tuple of int
+            Return a shape compliant with rfft, i.e., omit the negative frequencies
+            terms resulting in a return shape (*shape[:-1], shape[-1]//2+1)
         **kwargs : Dict
             Additional keyword arguments.
 
         Returns
         -------
-        Dict
-            Filter data and associated parameters.
+        dict
+            data: BackendArray
+                The filter mask.
+            shape: tuple of ints
+                The requested filter shape
+            return_real_fourier: bool
+                Whether data is compliant with rfftn.
+            is_multiplicative_filter: bool
+                Whether the filter is multiplicative in Fourier space.
         """
         if data_rfft is None:
-            data_rfft = np.fft.rfftn(be.to_numpy_array(data))
+            data_rfft = be.rfftn(data)
 
         data_rfft = be.to_numpy_array(data_rfft)
-
         bins, radial_averages = self._compute_spectrum(
             data_rfft, n_bins, batch_dimension
         )
+        shape = tuple(int(x) for i, x in enumerate(shape) if i != batch_dimension)
 
-        if order is None:
-            cutoff = bins < radial_averages.size
-            filter_mask = np.zeros(bins.shape, radial_averages.dtype)
-            filter_mask[cutoff] = radial_averages[bins[cutoff]]
-        else:
-            shape = bins.shape
-            if kwargs.get("shape", False):
-                shape = compute_fourier_shape(
-                    shape=kwargs.get("shape"),
-                    shape_is_real_fourier=kwargs.get("shape_is_real_fourier", False),
-                )
-
-            filter_mask = self._interpolate_spectrum(
-                spectrum=radial_averages,
+        shape_filter = shape
+        if return_real_fourier:
+            shape_filter = compute_fourier_shape(
                 shape=shape,
-                shape_is_real_fourier=True,
+                shape_is_real_fourier=False,
             )
 
-        filter_mask = np.fft.ifftshift(
-            filter_mask,
-            axes=tuple(i for i in range(data_rfft.ndim - 1) if i != batch_dimension),
+        ret = self._interpolate_spectrum(
+            spectrum=radial_averages,
+            shape=shape_filter,
+            shape_is_real_fourier=return_real_fourier,
         )
 
+        ret = shift_fourier(data=ret, shape_is_real_fourier=return_real_fourier)
+
         return {
-            "data": be.to_backend_array(filter_mask),
+            "data": be.to_backend_array(ret),
+            "shape": shape,
+            "return_real_fourier": return_real_fourier,
             "is_multiplicative_filter": True,
         }
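Usage note (not part of the diff): a minimal sketch of how the reworked LinearWhiteningFilter.__call__ might be invoked after this change. Only the keyword names and return keys shown in the hunks above are taken from the diff; the import path, array sizes, and values are illustrative assumptions.

import numpy as np
from tme.filters import LinearWhiteningFilter

tomogram = np.random.rand(32, 32, 32).astype(np.float32)

whitening = LinearWhiteningFilter()
result = whitening(
    shape=tomogram.shape,        # new argument: shape of the returned filter
    data=tomogram,               # alternatively pass a precomputed data_rfft
    return_real_fourier=True,    # rfft layout: (*shape[:-1], shape[-1] // 2 + 1)
)
filter_mask = result["data"]     # multiplicative Fourier-space filter
print(result["shape"], result["return_real_fourier"], result["is_multiplicative_filter"])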
tme/matching_data.py CHANGED
@@ -1,8 +1,9 @@
-""" Class representation of template matching data.
+"""
+Class representation of template matching data.
 
-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -174,7 +175,7 @@ class MatchingData:
         target_pad: NDArray = None,
         template_pad: NDArray = None,
         invert_target: bool = False,
-    ) -> "MatchingData":
+    ) -> Tuple["MatchingData", Tuple]:
         """
         Subset class instance based on slices.
 
@@ -193,6 +194,8 @@ class MatchingData:
         -------
         :py:class:`MatchingData`
             Newly allocated subset of class instance.
+        Tuple
+            Translation offset to merge analyzers.
 
         Examples
         --------
@@ -250,8 +253,9 @@ class MatchingData:
         target_offset[mask] = [x.start for x in target_slice]
         mask = np.subtract(1, self._target_batch).astype(bool)
         template_offset = np.zeros(len(self._output_template_shape), dtype=int)
-        template_offset[mask] = [x.start for x in template_slice]
-        ret._translation_offset = tuple(x for x in target_offset)
+        template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
+
+        translation_offset = tuple(x for x in target_offset)
 
         ret.target_filter = self.target_filter
         ret.template_filter = self.template_filter
@@ -261,7 +265,7 @@ class MatchingData:
             template_dim=getattr(self, "_template_dim", None),
         )
 
-        return ret
+        return ret, translation_offset
 
     def to_backend(self):
         """
@@ -322,11 +326,6 @@ class MatchingData:
 
         target_ndim -= len(target_dims)
         template_ndim -= len(template_dims)
-
-        if target_ndim != template_ndim:
-            raise ValueError(
-                f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
-            )
         self._set_matching_dimension(
             target_dims=target_dims, template_dims=template_dims
         )
@@ -491,29 +490,26 @@ class MatchingData:
     def _fourier_padding(
         target_shape: Tuple[int],
         template_shape: Tuple[int],
-        pad_fourier: bool,
+        pad_target: bool = False,
         batch_mask: Tuple[int] = None,
     ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
-        fourier_pad = template_shape
-        fourier_shift = np.zeros_like(template_shape)
-
         if batch_mask is None:
             batch_mask = np.zeros_like(template_shape)
         batch_mask = np.asarray(batch_mask)
 
-        if not pad_fourier:
-            fourier_pad = np.ones(len(fourier_pad), dtype=int)
+        fourier_pad = np.ones(len(template_shape), dtype=int)
         fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
         fourier_pad = np.add(fourier_pad, batch_mask)
 
+        # Avoid padding batch dimensions
         pad_shape = np.maximum(target_shape, template_shape)
+        pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
        ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
         conv_shape, fast_shape, fast_ft_shape = ret
 
         template_mod = np.mod(template_shape, 2)
-        if not pad_fourier:
-            fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
-            fourier_shift = np.subtract(fourier_shift, template_mod)
+        fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
+        fourier_shift = np.subtract(fourier_shift, template_mod)
 
         shape_diff = np.multiply(
             np.subtract(target_shape, template_shape), 1 - batch_mask
@@ -522,34 +518,27 @@ class MatchingData:
         if np.sum(shape_mask):
             shape_shift = np.divide(shape_diff, 2)
             offset = np.mod(shape_diff, 2)
-            if pad_fourier:
-                offset = -np.subtract(
-                    offset,
-                    np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
-                )
-            else:
-                warnings.warn(
-                    "Template is larger than target and padding is turned off. Consider "
-                    "swapping them or activate padding. Correcting the shift for now."
-                )
+            warnings.warn(
+                "Template is larger than target and padding is turned off. Consider "
+                "swapping them or activate padding. Correcting the shift for now."
+            )
             shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
             fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
 
-        fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
+        if pad_target:
+            fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
 
+        fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
         return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
 
-    def fourier_padding(
-        self, pad_fourier: bool = False
-    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
+    def fourier_padding(self, pad_target: bool = False) -> Tuple:
         """
         Computes efficient shape four Fourier transforms and potential associated shifts.
 
         Parameters
         ----------
-        pad_fourier : bool, optional
-            If true, returns the shape of the full-convolution defined as sum of target
-            shape and template shape minus one, False by default.
+        pad_target : bool, optional
+            Whether the target has been padded to the full convolution shape.
 
         Returns
         -------
@@ -564,7 +553,7 @@ class MatchingData:
             target_shape=be.to_numpy_array(self._output_target_shape),
             template_shape=be.to_numpy_array(self._output_template_shape),
             batch_mask=be.to_numpy_array(self._batch_mask),
-            pad_fourier=pad_fourier,
+            pad_target=pad_target,
         )
 
     def computation_schedule(
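Usage note (not part of the diff): a hedged sketch of the two API changes above, the tuple return of subset_by_slice and the pad_target keyword of fourier_padding. The MatchingData constructor call, array sizes, and slices are assumptions for illustration only.

import numpy as np
from tme.matching_data import MatchingData

target = np.random.rand(64, 64, 64).astype(np.float32)
template = np.random.rand(16, 16, 16).astype(np.float32)
matching_data = MatchingData(target=target, template=template)  # constructor signature assumed

# subset_by_slice now returns the subset plus a translation offset for merging analyzers
subset, translation_offset = matching_data.subset_by_slice(
    target_slice=tuple(slice(0, 32) for _ in range(target.ndim)),
)

# fourier_padding takes pad_target instead of pad_fourier
conv, fast_shape, fast_ft_shape, fourier_shift = subset.fourier_padding(pad_target=False)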
tme/matching_exhaustive.py CHANGED
@@ -1,8 +1,9 @@
-""" Implements cross-correlation based template matching using different metrics.
+"""
+Implements cross-correlation based template matching using different metrics.
 
-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import sys
@@ -19,6 +20,7 @@ from .filters import Compose
 from .backends import backend as be
 from .matching_utils import split_shape
 from .types import CallbackClass, MatchingData
+from .analyzer.proxy import SharedAnalyzerProxy
 from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
 from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
 
@@ -147,7 +149,7 @@ def scan(
     n_jobs: int = 4,
     callback_class: CallbackClass = None,
     callback_class_args: Dict = {},
-    pad_fourier: bool = True,
+    pad_target: bool = True,
     pad_template_filter: bool = True,
     interpolation_order: int = 3,
     jobs_per_callback_class: int = 8,
@@ -174,8 +176,8 @@
         Analyzer class pointer to operate on computed scores.
     callback_class_args : dict, optional
         Arguments passed to the callback_class. Default is an empty dictionary.
-    pad_fourier: bool, optional
-        Whether to pad target and template to the full convolution shape.
+    pad_target: bool, optional
+        Whether to pad target to the full convolution shape.
     pad_template_filter: bool, optional
         Whether to pad potential template filters to the full convolution shape.
     interpolation_order : int, optional
@@ -208,17 +210,17 @@
     >>> )
 
     """
-    matching_data = matching_data.subset_by_slice(
+    matching_data, translation_offset = matching_data.subset_by_slice(
         target_slice=target_slice,
         template_slice=template_slice,
-        target_pad=matching_data.target_padding(pad_target=pad_fourier),
+        target_pad=matching_data.target_padding(pad_target=pad_target),
     )
 
     matching_data.to_backend()
     template_shape = matching_data._batch_shape(
         matching_data.template.shape, matching_data._template_batch
     )
-    conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=False)
+    conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
 
     template_filter = _setup_template_filter_apply_target_filter(
         matching_data=matching_data,
@@ -229,19 +231,20 @@
 
     default_callback_args = {
         "shape": fwd,
-        "offset": matching_data._translation_offset,
+        "offset": translation_offset,
         "fourier_shift": shift,
         "fast_shape": fwd,
         "targetshape": matching_data._output_shape,
         "templateshape": template_shape,
         "convolution_shape": conv,
         "thread_safe": n_jobs > 1,
-        "convolution_mode": "valid" if pad_fourier else "same",
+        "convolution_mode": "valid" if pad_target else "same",
         "shm_handler": shm_handler,
         "only_unique_rotations": True,
         "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
         "n_rotations": matching_data.rotations.shape[0],
     }
+    callback_class_args["inversion_mapping"] = n_jobs == 1
     default_callback_args.update(callback_class_args)
 
     setup = matching_setup(
@@ -257,13 +260,17 @@
     matching_data._free_data()
     be.free_cache()
 
-    # Some analyzers cannot be shared across processes
-    if not getattr(callback_class, "is_shareable", False):
-        jobs_per_callback_class = 1
-
     n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
     callback_classes = [
-        callback_class(**default_callback_args) if callback_class else None
+        (
+            SharedAnalyzerProxy(
+                callback_class,
+                default_callback_args,
+                shm_handler=shm_handler if n_jobs > 1 else None,
+            )
+            if callback_class
+            else None
+        )
         for _ in range(n_callback_classes)
     ]
     ret = Parallel(n_jobs=n_jobs)(
@@ -276,14 +283,9 @@
         )
         for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
     )
-
-    # TODO: Make sure peak callers are thread safe to begin with
-    if not getattr(callback_class, "is_shareable", False):
-        callback_classes = ret
-
     callbacks = [
-        tuple(callback._postprocess(**default_callback_args))
-        for callback in callback_classes
+        callback.result(**default_callback_args)
+        for callback in ret[:n_callback_classes]
         if callback
     ]
     be.free_cache()
@@ -423,14 +425,17 @@ def scan_subsets(
     splits = tuple(product(target_splits, template_splits))
 
     outer_jobs, inner_jobs = job_schedule
-    if hasattr(be, "scan"):
+    if be._backend_name == "jax":
+        func = be.scan
+
         corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
-        results = be.scan(
+        results = func(
             matching_data=matching_data,
             splits=splits,
             n_jobs=outer_jobs,
             rotate_mask=matching_score != corr_scoring,
             callback_class=callback_class,
+            callback_class_args=callback_class_args,
         )
     else:
         results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
@@ -445,7 +450,7 @@
                 callback_class=callback_class,
                 callback_class_args=callback_class_args,
                 interpolation_order=interpolation_order,
-                pad_fourier=pad_target_edges,
+                pad_target=pad_target_edges,
                 gpu_index=index % outer_jobs,
                 pad_template_filter=pad_template_filter,
                 target_slice=target_split,
@@ -454,7 +459,6 @@
             for index, (target_split, template_split) in enumerate(splits)
         ]
     )
-
     matching_data._free_data()
     if callback_class is not None:
        return callback_class.merge(results, **callback_class_args)
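Background note (not part of the diff): the renamed pad_target flag now selects between the "valid" and "same" convolution modes set in default_callback_args above. The SciPy snippet below only illustrates how those modes affect output shapes; it does not use pytme.

import numpy as np
from scipy.signal import fftconvolve

target = np.random.rand(100, 100)
template = np.random.rand(16, 16)

full = fftconvolve(target, template, mode="full")    # (115, 115): target + template - 1
same = fftconvolve(target, template, mode="same")    # (100, 100): analogous to pad_target=False
valid = fftconvolve(target, template, mode="valid")  # (85, 85):   analogous to pad_target=True, s1 - s2 + 1
print(full.shape, same.shape, valid.shape)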
tme/matching_optimization.py CHANGED
@@ -1,8 +1,9 @@
-""" Implements methods for non-exhaustive template matching.
+"""
+Implements methods for non-exhaustive template matching.
 
-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -1308,4 +1309,4 @@ def optimize_match(
         result.x = np.zeros_like(result.x)
     translation, rotation = result.x[:ndim], result.x[ndim:]
     rotation_matrix = euler_to_rotationmatrix(rotation)
-    return translation, rotation_matrix, result.fun
+    return translation, rotation_matrix, float(result.fun)
tme/matching_scores.py CHANGED
@@ -1,8 +1,9 @@
-""" Implements a range of cross-correlation coefficients.
+"""
+Implements a range of cross-correlation coefficients.
 
-    Copyright (c) 2023-2024 European Molecular Biology Laboratory
+Copyright (c) 2023-2024 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import warnings
@@ -592,20 +593,27 @@ def corr_scoring(
         **_fftargs,
     )
 
+    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
     unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
+
+    template_rot = be.zeros(template.shape, be._float_dtype)
     for index in range(rotations.shape[0]):
+        # d+1, d+1 rigid transform matrix from d,d rotation matrix
         rotation = rotations[index]
-        arr = be.fill(arr, 0)
-        arr, _ = be.rigid_transform(
+        matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
+        template_rot, _ = be.rigid_transform(
             arr=template,
-            rotation_matrix=rotation,
-            out=arr,
-            use_geometric_center=True,
+            rotation_matrix=matrix,
+            out=template_rot,
             order=interpolation_order,
-            cache=False,
+            cache=True,
         )
-        arr = template_filter_func(arr, ft_temp, template_filter)
-        norm_template(arr[unpadded_slice], template_mask, mask_sum)
+
+        template_rot = template_filter_func(template_rot, ft_temp, template_filter)
+        norm_template(template_rot, template_mask, mask_sum)
+
+        arr = be.fill(arr, 0)
+        arr[unpadded_slice] = template_rot
 
         ft_temp = rfftn(arr, ft_temp)
         ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
@@ -728,7 +736,7 @@ def flc_scoring(
             out_mask=temp,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=False,
+            cache=True,
         )
 
         n_obs = be.sum(temp)
@@ -874,7 +882,7 @@ def mcc_scoring(
             out_mask=temp,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=False,
+            cache=True,
         )
 
         template_filter_func(template_rot, temp_ft, template_filter)
@@ -1034,7 +1042,8 @@ def flc_scoring2(
             out_mask=tmp_sqz,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=False,
+            cache=True,
+            batched=True,
         )
         n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
         arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
@@ -1154,7 +1163,8 @@ def corr_scoring2(
             out=arr_sqz,
             use_geometric_center=True,
             order=interpolation_order,
-            cache=False,
+            cache=True,
+            batched=True,
         )
         arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
         norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
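Background note (not part of the diff): corr_scoring now builds a (d+1, d+1) rigid transform matrix from a (d, d) rotation matrix and the template center. be._rigid_transform_matrix is backend internal, so the NumPy function below is only a sketch of the underlying homogeneous-coordinate construction; sign and axis conventions in the backend may differ.

import numpy as np

def rigid_transform_matrix(rotation_matrix: np.ndarray, center: np.ndarray) -> np.ndarray:
    # Homogeneous matrix that rotates about `center` instead of the array origin
    d = rotation_matrix.shape[0]
    matrix = np.eye(d + 1)
    matrix[:d, :d] = rotation_matrix
    matrix[:d, d] = center - rotation_matrix @ center
    return matrix

rotation = np.array([[0.0, -1.0], [1.0, 0.0]])  # 90 degree rotation in 2D
center = np.array([15.5, 15.5])                 # (shape - 1) / 2 for a 32 x 32 template
print(rigid_transform_matrix(rotation, center))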
tme/matching_utils.py CHANGED
@@ -1,8 +1,9 @@
-""" Utility functions for template matching.
+"""
+Utility functions for template matching.
 
-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory
 
-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 
 import os
@@ -519,7 +520,7 @@ def apply_convolution_mode(
     elif convolution_mode == "same":
         return func(arr, s1)
     elif convolution_mode == "valid":
-        valid_shape = [s1[i] - s2[i] + s2[i] % 2 for i in range(arr.ndim)]
+        valid_shape = [s1[i] - s2[i] + 1 for i in range(arr.ndim)]
         return func(arr, valid_shape)
 
 
@@ -1056,7 +1057,7 @@ def tube_mask(
 
 def scramble_phases(
     arr: NDArray,
-    noise_proportion: float = 0.5,
+    noise_proportion: float = 1.0,
     seed: int = 42,
     normalize_power: bool = False,
 ) -> NDArray:
@@ -1068,7 +1069,7 @@
     arr : NDArray
         Input data.
     noise_proportion : float, optional
-        Proportion of scrambled phases, 0.5 by default.
+        Proportion of scrambled phases, 1.0 by default.
     seed : int, optional
         The seed for the random phase scrambling, 42 by default.
     normalize_power : bool, optional
@@ -1079,15 +1080,22 @@
     NDArray
         Phase scrambled version of ``arr``.
     """
+    from tme.filters._utils import fftfreqn
+
     np.random.seed(seed)
     noise_proportion = max(min(noise_proportion, 1), 0)
 
     arr_fft = np.fft.fftn(arr)
     amp, ph = np.abs(arr_fft), np.angle(arr_fft)
 
-    ph_noise = np.random.permutation(ph)
-    ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
-    ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
+    # Scrambling up to nyquist gives more uniform noise distribution
+    mask = np.fft.ifftshift(
+        fftfreqn(arr_fft.shape, sampling_rate=1, compute_euclidean_norm=True) <= 0.5
+    )
+
+    ph_noise = np.random.permutation(ph[mask])
+    ph[mask] = ph[mask] * (1 - noise_proportion) + ph_noise * noise_proportion
+    ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph)))
 
     if normalize_power:
         np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
@@ -1150,3 +1158,40 @@ def compute_extraction_box(
     keep = be.multiply(keep, clamp_change == 0)
 
     return obs_beg_clamp, obs_end_clamp, cand_beg, cand_end, keep
+
+
+class TqdmParallel(Parallel):
+    """
+    A minimal Parallel implementation using tqdm for progress reporting.
+
+    Parameters:
+    -----------
+    tqdm_args : dict, optional
+        Dictionary of arguments passed to tqdm.tqdm
+    *args, **kwargs:
+        Arguments to pass to joblib.Parallel
+    """
+
+    def __init__(self, tqdm_args: Dict = {}, *args, **kwargs):
+        from tqdm import tqdm
+
+        super().__init__(*args, **kwargs)
+        self.pbar = tqdm(**tqdm_args)
+
+    def __call__(self, iterable, *args, **kwargs):
+        self.n_tasks = len(iterable) if hasattr(iterable, "__len__") else None
+        return super().__call__(iterable, *args, **kwargs)
+
+    def print_progress(self):
+        if self.n_tasks is None:
+            return super().print_progress()
+
+        if self.n_tasks != self.pbar.total:
+            self.pbar.total = self.n_tasks
+            self.pbar.refresh()
+
+        self.pbar.n = self.n_completed_tasks
+        self.pbar.refresh()
+
+        if self.n_completed_tasks >= self.n_tasks:
+            self.pbar.close()
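Usage note (not part of the diff): a sketch of how the TqdmParallel helper added above can be driven with joblib.delayed. The squaring task and the tqdm keyword arguments are illustrative, and tqdm must be installed.

from joblib import delayed
from tme.matching_utils import TqdmParallel

tasks = [delayed(pow)(i, 2) for i in range(100)]
results = TqdmParallel(tqdm_args={"desc": "Matching", "unit": "task"}, n_jobs=2)(tasks)
print(len(results))

Because __call__ records len(iterable), passing a list rather than a generator lets print_progress report a total instead of falling back to joblib's default output.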