pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55):
  1. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
  3. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
  4. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
  8. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
  9. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/match_template.py +163 -201
  14. scripts/match_template_filters.py +1200 -0
  15. scripts/postprocess.py +48 -39
  16. scripts/preprocess.py +10 -23
  17. scripts/preprocessor_gui.py +3 -4
  18. scripts/pytme_runner.py +769 -0
  19. scripts/refine_matches.py +0 -1
  20. tests/preprocessing/test_frequency_filters.py +19 -10
  21. tests/test_analyzer.py +122 -122
  22. tests/test_backends.py +1 -0
  23. tests/test_matching_cli.py +30 -30
  24. tests/test_matching_data.py +5 -5
  25. tests/test_matching_utils.py +1 -1
  26. tme/__version__.py +1 -1
  27. tme/analyzer/__init__.py +1 -1
  28. tme/analyzer/_utils.py +1 -4
  29. tme/analyzer/aggregation.py +15 -6
  30. tme/analyzer/base.py +25 -36
  31. tme/analyzer/peaks.py +39 -113
  32. tme/analyzer/proxy.py +1 -0
  33. tme/backends/_jax_utils.py +16 -15
  34. tme/backends/cupy_backend.py +9 -13
  35. tme/backends/jax_backend.py +19 -16
  36. tme/backends/npfftw_backend.py +27 -25
  37. tme/backends/pytorch_backend.py +4 -0
  38. tme/density.py +5 -4
  39. tme/filters/__init__.py +2 -2
  40. tme/filters/_utils.py +32 -7
  41. tme/filters/bandpass.py +225 -186
  42. tme/filters/ctf.py +117 -67
  43. tme/filters/reconstruction.py +38 -9
  44. tme/filters/wedge.py +88 -105
  45. tme/filters/whitening.py +1 -6
  46. tme/matching_data.py +24 -36
  47. tme/matching_exhaustive.py +14 -11
  48. tme/matching_scores.py +21 -12
  49. tme/matching_utils.py +13 -6
  50. tme/orientations.py +13 -3
  51. tme/parser.py +109 -29
  52. tme/preprocessor.py +2 -2
  53. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  54. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  55. {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/filters/wedge.py CHANGED
@@ -7,8 +7,8 @@ Copyright (c) 2024 European Molecular Biology Laboratory
7
7
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
8
  """
9
9
 
10
- import warnings
11
10
  from typing import Tuple, Dict
11
+ from dataclasses import dataclass
12
12
 
13
13
  import numpy as np
14
14
 
@@ -16,8 +16,8 @@ from ..types import NDArray
16
16
  from ..backends import backend as be
17
17
  from .compose import ComposableFilter
18
18
  from ..matching_utils import centered
19
- from ..parser import XMLParser, StarParser
20
19
  from ..rotations import euler_to_rotationmatrix
20
+ from ..parser import XMLParser, StarParser, MDOCParser
21
21
  from ._utils import (
22
22
  centered_grid,
23
23
  frequency_grid_at_angle,
@@ -31,56 +31,28 @@ from ._utils import (
31
31
  __all__ = ["Wedge", "WedgeReconstructed"]
32
32
 
33
33
 
34
+ @dataclass
34
35
  class Wedge(ComposableFilter):
35
36
  """
36
37
  Generate wedge mask for tomographic data.
37
-
38
- Parameters
39
- ----------
40
- shape : tuple of int
41
- The shape of the reconstruction volume.
42
- tilt_axis : int
43
- Axis the plane is tilted over, defaults to 0 (x).
44
- opening_axis : int
45
- The projection axis, defaults to 2 (z).
46
- angles : tuple of float
47
- The tilt angles.
48
- weights : tuple of float
49
- The weights corresponding to each tilt angle.
50
- weight_type : str, optional
51
- The type of weighting to apply, defaults to None.
52
- frequency_cutoff : float, optional
53
- Frequency cutoff for created mask. Nyquist 0.5 by default.
54
-
55
- Returns
56
- -------
57
- dict
58
- data: BackendArray
59
- The filter mask.
60
- shape: tuple of ints
61
- The requested filter shape
62
- return_real_fourier: bool
63
- Whether data is compliant with rfftn.
64
- is_multiplicative_filter: bool
65
- Whether the filter is multiplicative in Fourier space.
66
38
  """
67
39
 
68
- def __init__(
69
- self,
70
- shape: Tuple[int],
71
- angles: Tuple[float],
72
- weights: Tuple[float],
73
- tilt_axis: int = 0,
74
- opening_axis: int = 2,
75
- weight_type: str = None,
76
- frequency_cutoff: float = 0.5,
77
- ):
78
- self.shape = shape
79
- self.tilt_axis = tilt_axis
80
- self.opening_axis = opening_axis
81
- self.angles = angles
82
- self.weights = weights
83
- self.frequency_cutoff = frequency_cutoff
40
+ #: The shape of the reconstruction volume.
41
+ shape: Tuple[int] = None
42
+ #: The tilt angles.
43
+ angles: Tuple[float] = None
44
+ #: The weights corresponding to each tilt angle.
45
+ weights: Tuple[float] = None
46
+ #: Axis the plane is tilted over, defaults to 0 (x).
47
+ tilt_axis: int = 0
48
+ #: The projection axis, defaults to 2 (z).
49
+ opening_axis: int = 2
50
+ #: The type of weighting to apply, defaults to None.
51
+ weight_type: str = None
52
+ #: Frequency cutoff for created mask. Nyquist 0.5 by default.
53
+ frequency_cutoff: float = 0.5
54
+ #: The sampling rate, defaults to 1 Ångstrom / voxel.
55
+ sampling_rate: Tuple[float] = 1
84
56
 
85
57
  @classmethod
86
58
  def from_file(cls, filename: str) -> "Wedge":
@@ -93,6 +65,8 @@ class Wedge(ComposableFilter):
93
65
  +-------+---------------------------------------------------------+
94
66
  | .xml | WARP/M XML file |
95
67
  +-------+---------------------------------------------------------+
68
+ | .mdoc | SerialEM file |
69
+ +-------+---------------------------------------------------------+
96
70
  | .* | Tab-separated file with optional column names |
97
71
  +-------+---------------------------------------------------------+
98
72
 
@@ -111,6 +85,8 @@ class Wedge(ComposableFilter):
111
85
  func = _from_xml
112
86
  elif filename.lower().endswith("star"):
113
87
  func = _from_star
88
+ elif filename.lower().endswith("mdoc"):
89
+ func = _from_mdoc
114
90
 
115
91
  data = func(filename)
116
92
  angles, weights = data.get("angles", None), data.get("weights", None)
@@ -132,6 +108,9 @@ class Wedge(ComposableFilter):
132
108
  )
133
109
 
134
110
  def __call__(self, **kwargs: Dict) -> NDArray:
111
+ """
112
+ Returns a Wedge stack of chosen parameters with DC component in the center.
113
+ """
135
114
  func_args = vars(self).copy()
136
115
  func_args.update(kwargs)
137
116
 
@@ -170,6 +149,7 @@ class Wedge(ComposableFilter):
170
149
  return {
171
150
  "data": ret,
172
151
  "shape": func_args["shape"],
152
+ "return_real_fourier": func_args.get("return_real_fourier", False),
173
153
  "is_multiplicative_filter": True,
174
154
  }
175
155
 
@@ -196,7 +176,12 @@ class Wedge(ComposableFilter):
196
176
  return wedges
197
177
 
198
178
  def weight_relion(
199
- self, shape: Tuple[int], opening_axis: int, tilt_axis: int, **kwargs
179
+ self,
180
+ shape: Tuple[int],
181
+ opening_axis: int,
182
+ tilt_axis: int,
183
+ sampling_rate: float = 1.0,
184
+ **kwargs,
200
185
  ) -> NDArray:
201
186
  """
202
187
  Generate weighted wedges based on the RELION 1.4 formalism, weighting each
@@ -211,7 +196,6 @@ class Wedge(ComposableFilter):
211
196
  tilt_shape = compute_tilt_shape(
212
197
  shape=shape, opening_axis=opening_axis, reduce_dim=True
213
198
  )
214
-
215
199
  wedges = np.zeros((len(self.angles), *tilt_shape))
216
200
  for index, angle in enumerate(self.angles):
217
201
  frequency_grid = frequency_grid_at_angle(
@@ -219,7 +203,7 @@ class Wedge(ComposableFilter):
219
203
  opening_axis=opening_axis,
220
204
  tilt_axis=tilt_axis,
221
205
  angle=angle,
222
- sampling_rate=1,
206
+ sampling_rate=sampling_rate,
223
207
  )
224
208
  sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
225
209
  sigma = -2 * np.pi**2 * sigma**2
@@ -239,6 +223,7 @@ class Wedge(ComposableFilter):
239
223
  amplitude: float = 0.245,
240
224
  power: float = -1.665,
241
225
  offset: float = 2.81,
226
+ sampling_rate: float = 1.0,
242
227
  **kwargs,
243
228
  ) -> NDArray:
244
229
  """
@@ -264,7 +249,7 @@ class Wedge(ComposableFilter):
264
249
  opening_axis=opening_axis,
265
250
  tilt_axis=tilt_axis,
266
251
  angle=angle,
267
- sampling_rate=1,
252
+ sampling_rate=sampling_rate,
268
253
  )
269
254
 
270
255
  with np.errstate(divide="ignore"):
@@ -283,55 +268,35 @@ class Wedge(ComposableFilter):
283
268
  return wedges
284
269
 
285
270
 
271
+ @dataclass
286
272
  class WedgeReconstructed:
287
273
  """
288
274
  Initialize :py:class:`WedgeReconstructed`.
289
-
290
- Parameters
291
- ----------
292
- angles :tuple of float, optional
293
- The tilt angles, defaults to None.
294
- tilt_axis : int
295
- Axis the plane is tilted over, defaults to 0 (x).
296
- opening_axis : int
297
- The projection axis, defaults to 2 (z).
298
- weights : tuple of float, optional
299
- Weights to assign to individual wedge components.
300
- weight_wedge : bool, optional
301
- Whether individual wedge components should be weighted. If True and weights
302
- is None, uses the cosine of the angle otherwise weights.
303
- create_continuous_wedge: bool, optional
304
- Whether to create a continous wedge or a per-component wedge. Weights are only
305
- considered for non-continuous wedges.
306
- frequency_cutoff : float, optional
307
- Filter window applied during reconstruction.
308
- **kwargs : Dict
309
- Additional keyword arguments.
310
275
  """
311
276
 
312
- def __init__(
313
- self,
314
- opening_axis: int = 2,
315
- tilt_axis: int = 0,
316
- angles: Tuple[float] = None,
317
- weights: Tuple[float] = None,
318
- weight_wedge: bool = False,
319
- create_continuous_wedge: bool = False,
320
- frequency_cutoff: float = 0.5,
321
- reconstruction_filter: str = None,
322
- **kwargs: Dict,
323
- ):
324
- self.angles = angles
325
- self.opening_axis = opening_axis
326
- self.tilt_axis = tilt_axis
327
- self.weights = weights
328
- self.weight_wedge = weight_wedge
329
- self.reconstruction_filter = reconstruction_filter
330
- self.create_continuous_wedge = create_continuous_wedge
331
- self.frequency_cutoff = frequency_cutoff
277
+ #: The tilt angles, defaults to None.
278
+ angles: Tuple[float] = None
279
+ #: Weights to assign to individual wedge components. Not considered for continuous wedge
280
+ weights: Tuple[float] = None
281
+ #: Whether individual wedge components should be weighted.
282
+ weight_wedge: bool = False
283
+ #: Whether to create a continous wedge or a per-component wedge.
284
+ create_continuous_wedge: bool = False
285
+ #: Frequency cutoff of filter
286
+ frequency_cutoff: float = 0.5
287
+ #: Axis the plane is tilted over, defaults to 0 (x).
288
+ tilt_axis: int = 0
289
+ #: The projection axis, defaults to 2 (z).
290
+ opening_axis: int = 2
291
+ #: Filter window applied during reconstruction.
292
+ reconstruction_filter: str = None
293
+
294
+ def __post_init__(self):
295
+ if self.create_continuous_wedge:
296
+ self.angles = (min(self.angles), max(self.angles))
332
297
 
333
298
  def __call__(
334
- self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs: Dict
299
+ self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs
335
300
  ) -> Dict:
336
301
  """
337
302
  Generate the reconstructed wedge.
@@ -341,10 +306,8 @@ class WedgeReconstructed:
341
306
  shape : tuple of int
342
307
  The shape of the reconstruction volume.
343
308
  return_real_fourier : tuple of int
344
- Return a shape compliant with rfft, i.e., omit the negative frequencies
345
- terms resulting in a return shape (*shape[:-1], shape[-1]//2+1). Defaults
346
- to False.
347
- **kwargs : Dict
309
+ Return a shape compliant with rfftn. Defaults to False.
310
+ **kwargs : dict
348
311
  Additional keyword arguments.
349
312
 
350
313
  Returns
@@ -373,7 +336,6 @@ class WedgeReconstructed:
373
336
  )
374
337
 
375
338
  ret = func(shape=shape, **func_args)
376
-
377
339
  frequency_cutoff = func_args.get("frequency_cutoff", None)
378
340
  if frequency_cutoff is not None:
379
341
  frequency_mask = fftfreqn(
@@ -565,7 +527,31 @@ def _from_xml(filename: str, **kwargs) -> Dict:
565
527
 
566
528
  def _from_star(filename: str, **kwargs) -> Dict:
567
529
  """
568
- Read tilt data from a tomostar STAR file.
530
+ Read tilt data from a STAR file.
531
+
532
+ Parameters
533
+ ----------
534
+ filename : str
535
+ The path to the text file.
536
+
537
+ Returns
538
+ -------
539
+ Dict
540
+ A dictionary with one key for each column.
541
+ """
542
+ data = StarParser(filename, delimiter=None)
543
+ if "data_stopgap_wedgelist" in data:
544
+ angles = data["data_stopgap_wedgelist"]["_tilt_angle"]
545
+ weights = data["data_stopgap_wedgelist"]["_exposure"]
546
+ else:
547
+ angles = data["data_"]["_wrpAxisAngle"]
548
+ weights = data["data_"]["_wrpDose"]
549
+ return {"angles": angles, "weights": weights}
550
+
551
+
552
+ def _from_mdoc(filename: str, **kwargs) -> Dict:
553
+ """
554
+ Read tilt data from a SerialEM MDOC file.
569
555
 
570
556
  Parameters
571
557
  ----------
@@ -577,8 +563,9 @@ def _from_star(filename: str, **kwargs) -> Dict:
577
563
  Dict
578
564
  A dictionary with one key for each column.
579
565
  """
580
- data = StarParser(filename, delimiter=None)["data_"]
581
- return {"angles": data["_wrpAxisAngle"], "weights": data["_wrpDose"]}
566
+ data = MDOCParser(filename)
567
+ cumulative_exposure = np.multiply(np.add(1, data["ZValue"]), data["ExposureDose"])
568
+ return {"angles": data["TiltAngle"], "weights": cumulative_exposure}
582
569
 
583
570
 
584
571
  def _from_text(filename: str, **kwargs) -> Dict:
@@ -604,10 +591,6 @@ def _from_text(filename: str, **kwargs) -> Dict:
604
591
  if "angles" in data[0]:
605
592
  headers = data.pop(0)
606
593
  else:
607
- warnings.warn(
608
- f"Did not find a column named 'angles' in {filename}. Assuming "
609
- "first column specifies angles."
610
- )
611
594
  if len(data[0]) != 1:
612
595
  raise ValueError(
613
596
  "Found more than one column without column names. Please add "
tme/filters/whitening.py CHANGED
@@ -24,11 +24,6 @@ class LinearWhiteningFilter(ComposableFilter):
24
24
  """
25
25
  Compute Fourier power spectrums and perform whitening.
26
26
 
27
- Parameters
28
- ----------
29
- **kwargs : Dict, optional
30
- Additional keyword arguments.
31
-
32
27
  References
33
28
  ----------
34
29
  .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
@@ -39,7 +34,7 @@ class LinearWhiteningFilter(ComposableFilter):
39
34
  13375 (2023)
40
35
  """
41
36
 
42
- def __init__(self, **kwargs):
37
+ def __init__(self, *args, **kwargs):
43
38
  pass
44
39
 
45
40
  @staticmethod
tme/matching_data.py CHANGED
@@ -175,7 +175,7 @@ class MatchingData:
175
175
  target_pad: NDArray = None,
176
176
  template_pad: NDArray = None,
177
177
  invert_target: bool = False,
178
- ) -> "MatchingData":
178
+ ) -> Tuple["MatchingData", Tuple]:
179
179
  """
180
180
  Subset class instance based on slices.
181
181
 
@@ -194,6 +194,8 @@ class MatchingData:
194
194
  -------
195
195
  :py:class:`MatchingData`
196
196
  Newly allocated subset of class instance.
197
+ Tuple
198
+ Translation offset to merge analyzers.
197
199
 
198
200
  Examples
199
201
  --------
@@ -251,8 +253,9 @@ class MatchingData:
251
253
  target_offset[mask] = [x.start for x in target_slice]
252
254
  mask = np.subtract(1, self._target_batch).astype(bool)
253
255
  template_offset = np.zeros(len(self._output_template_shape), dtype=int)
254
- template_offset[mask] = [x.start for x in template_slice]
255
- ret._translation_offset = tuple(x for x in target_offset)
256
+ template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
257
+
258
+ translation_offset = tuple(x for x in target_offset)
256
259
 
257
260
  ret.target_filter = self.target_filter
258
261
  ret.template_filter = self.template_filter
@@ -262,7 +265,7 @@ class MatchingData:
262
265
  template_dim=getattr(self, "_template_dim", None),
263
266
  )
264
267
 
265
- return ret
268
+ return ret, translation_offset
266
269
 
267
270
  def to_backend(self):
268
271
  """
@@ -323,11 +326,6 @@ class MatchingData:
323
326
 
324
327
  target_ndim -= len(target_dims)
325
328
  template_ndim -= len(template_dims)
326
-
327
- if target_ndim != template_ndim:
328
- raise ValueError(
329
- f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
330
- )
331
329
  self._set_matching_dimension(
332
330
  target_dims=target_dims, template_dims=template_dims
333
331
  )
@@ -492,29 +490,26 @@ class MatchingData:
492
490
  def _fourier_padding(
493
491
  target_shape: Tuple[int],
494
492
  template_shape: Tuple[int],
495
- pad_fourier: bool,
493
+ pad_target: bool = False,
496
494
  batch_mask: Tuple[int] = None,
497
495
  ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
498
- fourier_pad = template_shape
499
- fourier_shift = np.zeros_like(template_shape)
500
-
501
496
  if batch_mask is None:
502
497
  batch_mask = np.zeros_like(template_shape)
503
498
  batch_mask = np.asarray(batch_mask)
504
499
 
505
- if not pad_fourier:
506
- fourier_pad = np.ones(len(fourier_pad), dtype=int)
500
+ fourier_pad = np.ones(len(template_shape), dtype=int)
507
501
  fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
508
502
  fourier_pad = np.add(fourier_pad, batch_mask)
509
503
 
504
+ # Avoid padding batch dimensions
510
505
  pad_shape = np.maximum(target_shape, template_shape)
506
+ pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
511
507
  ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
512
508
  conv_shape, fast_shape, fast_ft_shape = ret
513
509
 
514
510
  template_mod = np.mod(template_shape, 2)
515
- if not pad_fourier:
516
- fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
517
- fourier_shift = np.subtract(fourier_shift, template_mod)
511
+ fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
512
+ fourier_shift = np.subtract(fourier_shift, template_mod)
518
513
 
519
514
  shape_diff = np.multiply(
520
515
  np.subtract(target_shape, template_shape), 1 - batch_mask
@@ -523,34 +518,27 @@ class MatchingData:
523
518
  if np.sum(shape_mask):
524
519
  shape_shift = np.divide(shape_diff, 2)
525
520
  offset = np.mod(shape_diff, 2)
526
- if pad_fourier:
527
- offset = -np.subtract(
528
- offset,
529
- np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
530
- )
531
- else:
532
- warnings.warn(
533
- "Template is larger than target and padding is turned off. Consider "
534
- "swapping them or activate padding. Correcting the shift for now."
535
- )
521
+ warnings.warn(
522
+ "Template is larger than target and padding is turned off. Consider "
523
+ "swapping them or activate padding. Correcting the shift for now."
524
+ )
536
525
  shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
537
526
  fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
538
527
 
539
- fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
528
+ if pad_target:
529
+ fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
540
530
 
531
+ fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
541
532
  return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
542
533
 
543
- def fourier_padding(
544
- self, pad_fourier: bool = False
545
- ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
534
+ def fourier_padding(self, pad_target: bool = False) -> Tuple:
546
535
  """
547
536
  Computes efficient shape four Fourier transforms and potential associated shifts.
548
537
 
549
538
  Parameters
550
539
  ----------
551
- pad_fourier : bool, optional
552
- If true, returns the shape of the full-convolution defined as sum of target
553
- shape and template shape minus one, False by default.
540
+ pad_target : bool, optional
541
+ Whether the target has been padded to the full convolution shape.
554
542
 
555
543
  Returns
556
544
  -------
@@ -565,7 +553,7 @@ class MatchingData:
565
553
  target_shape=be.to_numpy_array(self._output_target_shape),
566
554
  template_shape=be.to_numpy_array(self._output_template_shape),
567
555
  batch_mask=be.to_numpy_array(self._batch_mask),
568
- pad_fourier=pad_fourier,
556
+ pad_target=pad_target,
569
557
  )
570
558
 
571
559
  def computation_schedule(
@@ -149,7 +149,7 @@ def scan(
149
149
  n_jobs: int = 4,
150
150
  callback_class: CallbackClass = None,
151
151
  callback_class_args: Dict = {},
152
- pad_fourier: bool = True,
152
+ pad_target: bool = True,
153
153
  pad_template_filter: bool = True,
154
154
  interpolation_order: int = 3,
155
155
  jobs_per_callback_class: int = 8,
@@ -176,8 +176,8 @@ def scan(
176
176
  Analyzer class pointer to operate on computed scores.
177
177
  callback_class_args : dict, optional
178
178
  Arguments passed to the callback_class. Default is an empty dictionary.
179
- pad_fourier: bool, optional
180
- Whether to pad target and template to the full convolution shape.
179
+ pad_target: bool, optional
180
+ Whether to pad target to the full convolution shape.
181
181
  pad_template_filter: bool, optional
182
182
  Whether to pad potential template filters to the full convolution shape.
183
183
  interpolation_order : int, optional
@@ -210,17 +210,17 @@ def scan(
210
210
  >>> )
211
211
 
212
212
  """
213
- matching_data = matching_data.subset_by_slice(
213
+ matching_data, translation_offset = matching_data.subset_by_slice(
214
214
  target_slice=target_slice,
215
215
  template_slice=template_slice,
216
- target_pad=matching_data.target_padding(pad_target=pad_fourier),
216
+ target_pad=matching_data.target_padding(pad_target=pad_target),
217
217
  )
218
218
 
219
219
  matching_data.to_backend()
220
220
  template_shape = matching_data._batch_shape(
221
221
  matching_data.template.shape, matching_data._template_batch
222
222
  )
223
- conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=False)
223
+ conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
224
224
 
225
225
  template_filter = _setup_template_filter_apply_target_filter(
226
226
  matching_data=matching_data,
@@ -231,14 +231,14 @@ def scan(
231
231
 
232
232
  default_callback_args = {
233
233
  "shape": fwd,
234
- "offset": matching_data._translation_offset,
234
+ "offset": translation_offset,
235
235
  "fourier_shift": shift,
236
236
  "fast_shape": fwd,
237
237
  "targetshape": matching_data._output_shape,
238
238
  "templateshape": template_shape,
239
239
  "convolution_shape": conv,
240
240
  "thread_safe": n_jobs > 1,
241
- "convolution_mode": "valid" if pad_fourier else "same",
241
+ "convolution_mode": "valid" if pad_target else "same",
242
242
  "shm_handler": shm_handler,
243
243
  "only_unique_rotations": True,
244
244
  "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
@@ -425,14 +425,17 @@ def scan_subsets(
425
425
  splits = tuple(product(target_splits, template_splits))
426
426
 
427
427
  outer_jobs, inner_jobs = job_schedule
428
- if hasattr(be, "scan"):
428
+ if be._backend_name == "jax":
429
+ func = be.scan
430
+
429
431
  corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
430
- results = be.scan(
432
+ results = func(
431
433
  matching_data=matching_data,
432
434
  splits=splits,
433
435
  n_jobs=outer_jobs,
434
436
  rotate_mask=matching_score != corr_scoring,
435
437
  callback_class=callback_class,
438
+ callback_class_args=callback_class_args,
436
439
  )
437
440
  else:
438
441
  results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
@@ -447,7 +450,7 @@ def scan_subsets(
447
450
  callback_class=callback_class,
448
451
  callback_class_args=callback_class_args,
449
452
  interpolation_order=interpolation_order,
450
- pad_fourier=pad_target_edges,
453
+ pad_target=pad_target_edges,
451
454
  gpu_index=index % outer_jobs,
452
455
  pad_template_filter=pad_template_filter,
453
456
  target_slice=target_split,
tme/matching_scores.py CHANGED
@@ -593,20 +593,27 @@ def corr_scoring(
593
593
  **_fftargs,
594
594
  )
595
595
 
596
+ center = be.divide(be.to_backend_array(template.shape) - 1, 2)
596
597
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
598
+
599
+ template_rot = be.zeros(template.shape, be._float_dtype)
597
600
  for index in range(rotations.shape[0]):
601
+ # d+1, d+1 rigid transform matrix from d,d rotation matrix
598
602
  rotation = rotations[index]
599
- arr = be.fill(arr, 0)
600
- arr, _ = be.rigid_transform(
603
+ matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
604
+ template_rot, _ = be.rigid_transform(
601
605
  arr=template,
602
- rotation_matrix=rotation,
603
- out=arr,
604
- use_geometric_center=True,
606
+ rotation_matrix=matrix,
607
+ out=template_rot,
605
608
  order=interpolation_order,
606
- cache=False,
609
+ cache=True,
607
610
  )
608
- arr = template_filter_func(arr, ft_temp, template_filter)
609
- norm_template(arr[unpadded_slice], template_mask, mask_sum)
611
+
612
+ template_rot = template_filter_func(template_rot, ft_temp, template_filter)
613
+ norm_template(template_rot, template_mask, mask_sum)
614
+
615
+ arr = be.fill(arr, 0)
616
+ arr[unpadded_slice] = template_rot
610
617
 
611
618
  ft_temp = rfftn(arr, ft_temp)
612
619
  ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
@@ -729,7 +736,7 @@ def flc_scoring(
729
736
  out_mask=temp,
730
737
  use_geometric_center=True,
731
738
  order=interpolation_order,
732
- cache=False,
739
+ cache=True,
733
740
  )
734
741
 
735
742
  n_obs = be.sum(temp)
@@ -875,7 +882,7 @@ def mcc_scoring(
875
882
  out_mask=temp,
876
883
  use_geometric_center=True,
877
884
  order=interpolation_order,
878
- cache=False,
885
+ cache=True,
879
886
  )
880
887
 
881
888
  template_filter_func(template_rot, temp_ft, template_filter)
@@ -1035,7 +1042,8 @@ def flc_scoring2(
1035
1042
  out_mask=tmp_sqz,
1036
1043
  use_geometric_center=True,
1037
1044
  order=interpolation_order,
1038
- cache=False,
1045
+ cache=True,
1046
+ batched=True,
1039
1047
  )
1040
1048
  n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
1041
1049
  arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
@@ -1155,7 +1163,8 @@ def corr_scoring2(
1155
1163
  out=arr_sqz,
1156
1164
  use_geometric_center=True,
1157
1165
  order=interpolation_order,
1158
- cache=False,
1166
+ cache=True,
1167
+ batched=True,
1159
1168
  )
1160
1169
  arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
1161
1170
  norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)